Unverified Commit 451906e4 authored by pkulzc's avatar pkulzc Committed by GitHub
Browse files

Release MobileDet code and model, and require tf_slim installation for OD API. (#8562)

* Merged commit includes the following changes:
311933687  by Sergio Guadarrama:

    Removes spurios use of tf.compat.v2, which results in spurious tf.compat.v1.compat.v2. Adds basic test to nasnet_utils.
    Replaces all remaining import tensorflow as tf with import tensorflow.compat.v1 as tf

--
311766063  by Sergio Guadarrama:

    Removes explicit tf.compat.v1 in all call sites (we already import tf.compat.v1, so this code was  doing tf.compat.v1.compat.v1). The existing code worked in latest version of tensorflow, 2.2, (and 1.15) but not in 1.14 or in 2.0.0a, this CL fixes it.

--
311624958  by Sergio Guadarrama:

    Updates README that doesn't render properly in github documentation

--
310980959  by Sergio Guadarrama:

    Moves research_models/slim off tf.contrib.slim/layers/framework to tf_slim

--
310263156  by Sergio Guadarrama:

    Adds model breakdown for MobilenetV3

--
308640...
parent 73b5be67
...@@ -20,7 +20,7 @@ described in: ...@@ -20,7 +20,7 @@ described in:
T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar
""" """
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.anchor_generators import grid_anchor_generator from object_detection.anchor_generators import grid_anchor_generator
from object_detection.core import anchor_generator from object_detection.core import anchor_generator
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Tests for anchor_generators.multiscale_grid_anchor_generator_test.py.""" """Tests for anchor_generators.multiscale_grid_anchor_generator_test.py."""
import numpy as np import numpy as np
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.anchor_generators import multiscale_grid_anchor_generator as mg from object_detection.anchor_generators import multiscale_grid_anchor_generator as mg
from object_detection.utils import test_case from object_detection.utils import test_case
......
...@@ -28,7 +28,7 @@ Faster RCNN box coder follows the coding schema described below: ...@@ -28,7 +28,7 @@ Faster RCNN box coder follows the coding schema described below:
See http://arxiv.org/abs/1506.01497 for details. See http://arxiv.org/abs/1506.01497 for details.
""" """
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.core import box_coder from object_detection.core import box_coder
from object_detection.core import box_list from object_detection.core import box_list
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Tests for object_detection.box_coder.faster_rcnn_box_coder.""" """Tests for object_detection.box_coder.faster_rcnn_box_coder."""
import numpy as np import numpy as np
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.box_coders import faster_rcnn_box_coder from object_detection.box_coders import faster_rcnn_box_coder
from object_detection.core import box_list from object_detection.core import box_list
......
...@@ -35,7 +35,7 @@ to box coordinates): ...@@ -35,7 +35,7 @@ to box coordinates):
anchor-encoded keypoint coordinates. anchor-encoded keypoint coordinates.
""" """
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.core import box_coder from object_detection.core import box_coder
from object_detection.core import box_list from object_detection.core import box_list
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Tests for object_detection.box_coder.keypoint_box_coder.""" """Tests for object_detection.box_coder.keypoint_box_coder."""
import numpy as np import numpy as np
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.box_coders import keypoint_box_coder from object_detection.box_coders import keypoint_box_coder
from object_detection.core import box_list from object_detection.core import box_list
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Tests for object_detection.box_coder.mean_stddev_boxcoder.""" """Tests for object_detection.box_coder.mean_stddev_boxcoder."""
import numpy as np import numpy as np
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.box_coders import mean_stddev_box_coder from object_detection.box_coders import mean_stddev_box_coder
from object_detection.core import box_list from object_detection.core import box_list
......
...@@ -32,7 +32,7 @@ coder when the objects being detected tend to be square (e.g. faces) and when ...@@ -32,7 +32,7 @@ coder when the objects being detected tend to be square (e.g. faces) and when
the input images are not distorted via resizing. the input images are not distorted via resizing.
""" """
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.core import box_coder from object_detection.core import box_coder
from object_detection.core import box_list from object_detection.core import box_list
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Tests for object_detection.box_coder.square_box_coder.""" """Tests for object_detection.box_coder.square_box_coder."""
import numpy as np import numpy as np
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.box_coders import square_box_coder from object_detection.box_coders import square_box_coder
from object_detection.core import box_list from object_detection.core import box_list
......
...@@ -24,7 +24,7 @@ import math ...@@ -24,7 +24,7 @@ import math
from six.moves import range from six.moves import range
from six.moves import zip from six.moves import zip
import tensorflow as tf import tensorflow.compat.v1 as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.anchor_generators import flexible_grid_anchor_generator from object_detection.anchor_generators import flexible_grid_anchor_generator
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Tests for box_coder_builder.""" """Tests for box_coder_builder."""
import tensorflow as tf import tensorflow.compat.v1 as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.box_coders import faster_rcnn_box_coder from object_detection.box_coders import faster_rcnn_box_coder
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
"""Function to build box predictor from configuration.""" """Function to build box predictor from configuration."""
import collections import collections
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.predictors import convolutional_box_predictor from object_detection.predictors import convolutional_box_predictor
from object_detection.predictors import convolutional_keras_box_predictor from object_detection.predictors import convolutional_keras_box_predictor
from object_detection.predictors import mask_rcnn_box_predictor from object_detection.predictors import mask_rcnn_box_predictor
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Tests for box_predictor_builder.""" """Tests for box_predictor_builder."""
import mock import mock
import tensorflow as tf import tensorflow.compat.v1 as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.builders import box_predictor_builder from object_detection.builders import box_predictor_builder
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Tensorflow ops to calibrate class predictions and background class.""" """Tensorflow ops to calibrate class predictions and background class."""
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.utils import shape_utils from object_detection.utils import shape_utils
......
...@@ -22,7 +22,7 @@ from __future__ import print_function ...@@ -22,7 +22,7 @@ from __future__ import print_function
import numpy as np import numpy as np
from scipy import interpolate from scipy import interpolate
from six.moves import zip from six.moves import zip
import tensorflow as tf import tensorflow.compat.v1 as tf
from object_detection.builders import calibration_builder from object_detection.builders import calibration_builder
from object_detection.protos import calibration_pb2 from object_detection.protos import calibration_pb2
......
...@@ -27,7 +27,7 @@ from __future__ import division ...@@ -27,7 +27,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools import functools
import tensorflow as tf import tensorflow.compat.v1 as tf
from tensorflow.contrib import data as tf_data from tensorflow.contrib import data as tf_data
from object_detection.builders import decoder_builder from object_detection.builders import decoder_builder
...@@ -118,7 +118,7 @@ def shard_function_for_context(input_context): ...@@ -118,7 +118,7 @@ def shard_function_for_context(input_context):
def build(input_reader_config, batch_size=None, transform_input_data_fn=None, def build(input_reader_config, batch_size=None, transform_input_data_fn=None,
input_context=None): input_context=None, reduce_to_frame_fn=None):
"""Builds a tf.data.Dataset. """Builds a tf.data.Dataset.
Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
...@@ -132,6 +132,8 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None, ...@@ -132,6 +132,8 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None,
input_context: optional, A tf.distribute.InputContext object used to input_context: optional, A tf.distribute.InputContext object used to
shard filenames and compute per-replica batch_size when this function shard filenames and compute per-replica batch_size when this function
is being called per-replica. is being called per-replica.
reduce_to_frame_fn: Function that extracts frames from tf.SequenceExample
type input data.
Returns: Returns:
A tf.data.Dataset based on the input_reader_config. A tf.data.Dataset based on the input_reader_config.
...@@ -151,18 +153,9 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None, ...@@ -151,18 +153,9 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None,
if not config.input_path: if not config.input_path:
raise ValueError('At least one input path must be specified in ' raise ValueError('At least one input path must be specified in '
'`input_reader_config`.') '`input_reader_config`.')
def process_fn(value):
"""Sets up tf graph that decodes, transforms and pads input data."""
processed_tensors = decoder.decode(value)
if transform_input_data_fn is not None:
processed_tensors = transform_input_data_fn(processed_tensors)
return processed_tensors
shard_fn = shard_function_for_context(input_context) shard_fn = shard_function_for_context(input_context)
if input_context is not None: if input_context is not None:
batch_size = input_context.get_per_replica_batch_size(batch_size) batch_size = input_context.get_per_replica_batch_size(batch_size)
dataset = read_dataset( dataset = read_dataset(
functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000), functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
config.input_path[:], input_reader_config, filename_shard_fn=shard_fn) config.input_path[:], input_reader_config, filename_shard_fn=shard_fn)
...@@ -170,16 +163,12 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None, ...@@ -170,16 +163,12 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None,
dataset = dataset.shard(input_reader_config.sample_1_of_n_examples, 0) dataset = dataset.shard(input_reader_config.sample_1_of_n_examples, 0)
# TODO(rathodv): make batch size a required argument once the old binaries # TODO(rathodv): make batch size a required argument once the old binaries
# are deleted. # are deleted.
if batch_size: dataset = dataset.map(decoder.decode, tf.data.experimental.AUTOTUNE)
num_parallel_calls = batch_size * input_reader_config.num_parallel_batches if reduce_to_frame_fn:
else: dataset = reduce_to_frame_fn(dataset)
num_parallel_calls = input_reader_config.num_parallel_map_calls if transform_input_data_fn is not None:
# TODO(b/123952794): Migrate to V2 function. dataset = dataset.map(transform_input_data_fn,
if hasattr(dataset, 'map_with_legacy_function'): tf.data.experimental.AUTOTUNE)
data_map_fn = dataset.map_with_legacy_function
else:
data_map_fn = dataset.map
dataset = data_map_fn(process_fn, num_parallel_calls=num_parallel_calls)
if batch_size: if batch_size:
dataset = dataset.apply( dataset = dataset.apply(
tf_data.batch_and_drop_remainder(batch_size)) tf_data.batch_and_drop_remainder(batch_size))
......
...@@ -22,17 +22,17 @@ from __future__ import print_function ...@@ -22,17 +22,17 @@ from __future__ import print_function
import os import os
import numpy as np import numpy as np
from six.moves import range from six.moves import range
import tensorflow as tf import tensorflow.compat.v1 as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.builders import dataset_builder from object_detection.builders import dataset_builder
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.dataset_tools import seq_example_util
from object_detection.protos import input_reader_pb2 from object_detection.protos import input_reader_pb2
from object_detection.utils import dataset_util from object_detection.utils import dataset_util
from object_detection.utils import test_case from object_detection.utils import test_case
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
try: try:
from tensorflow.contrib import lookup as contrib_lookup from tensorflow.contrib import lookup as contrib_lookup
...@@ -43,15 +43,17 @@ except ImportError: ...@@ -43,15 +43,17 @@ except ImportError:
def get_iterator_next_for_testing(dataset, is_tf2): def get_iterator_next_for_testing(dataset, is_tf2):
iterator = dataset.make_initializable_iterator()
if not is_tf2:
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
return iterator.get_next()
# In TF2, lookup tables are not supported in one shot iterators, but def _get_labelmap_path():
# initialization is implicit. """Returns an absolute path to label map file."""
if is_tf2: parent_path = os.path.dirname(tf.resource_loader.get_data_files_path())
return dataset.make_initializable_iterator().get_next() return os.path.join(parent_path, 'data',
# In TF1, we use one shot iterator because it does not require running 'pet_label_map.pbtxt')
# a separate init op.
else:
return dataset.make_one_shot_iterator().get_next()
class DatasetBuilderTest(test_case.TestCase): class DatasetBuilderTest(test_case.TestCase):
...@@ -111,6 +113,57 @@ class DatasetBuilderTest(test_case.TestCase): ...@@ -111,6 +113,57 @@ class DatasetBuilderTest(test_case.TestCase):
return os.path.join(self.get_temp_dir(), '?????.tfrecord') return os.path.join(self.get_temp_dir(), '?????.tfrecord')
def _make_random_serialized_jpeg_images(self, num_frames, image_height,
image_width):
def graph_fn():
images = tf.cast(tf.random.uniform(
[num_frames, image_height, image_width, 3],
maxval=256,
dtype=tf.int32), dtype=tf.uint8)
images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
return encoded_images_list
encoded_images = self.execute(graph_fn, [])
return encoded_images
def create_tf_record_sequence_example(self):
path = os.path.join(self.get_temp_dir(), 'seq_tfrecord')
writer = tf.python_io.TFRecordWriter(path)
num_frames = 4
image_height = 4
image_width = 5
image_source_ids = [str(i) for i in range(num_frames)]
with self.test_session():
encoded_images = self._make_random_serialized_jpeg_images(
num_frames, image_height, image_width)
sequence_example_serialized = seq_example_util.make_sequence_example(
dataset_name='video_dataset',
video_id='video',
encoded_images=encoded_images,
image_height=image_height,
image_width=image_width,
image_source_ids=image_source_ids,
image_format='JPEG',
is_annotated=[[1], [1], [1], [1]],
bboxes=[
[[]], # Frame 0.
[[0., 0., 1., 1.]], # Frame 1.
[[0., 0., 1., 1.],
[0.1, 0.1, 0.2, 0.2]], # Frame 2.
[[]], # Frame 3.
],
label_strings=[
[], # Frame 0.
['Abyssinian'], # Frame 1.
['Abyssinian', 'american_bulldog'], # Frame 2.
[], # Frame 3
]).SerializeToString()
writer.write(sequence_example_serialized)
writer.close()
return path
def test_build_tf_record_input_reader(self): def test_build_tf_record_input_reader(self):
tf_record_path = self.create_tf_record() tf_record_path = self.create_tf_record()
...@@ -143,6 +196,71 @@ class DatasetBuilderTest(test_case.TestCase): ...@@ -143,6 +196,71 @@ class DatasetBuilderTest(test_case.TestCase):
[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0],
output_dict[fields.InputDataFields.groundtruth_boxes][0][0]) output_dict[fields.InputDataFields.groundtruth_boxes][0][0])
def get_mock_reduce_to_frame_fn(self):
def mock_reduce_to_frame_fn(dataset):
def get_frame(tensor_dict):
out_tensor_dict = {}
out_tensor_dict[fields.InputDataFields.source_id] = (
tensor_dict[fields.InputDataFields.source_id][0])
return out_tensor_dict
return dataset.map(get_frame, tf.data.experimental.AUTOTUNE)
return mock_reduce_to_frame_fn
def test_build_tf_record_input_reader_sequence_example_train(self):
tf_record_path = self.create_tf_record_sequence_example()
label_map_path = _get_labelmap_path()
input_type = 'TF_SEQUENCE_EXAMPLE'
input_reader_text_proto = """
shuffle: false
num_readers: 1
input_type: {1}
tf_record_input_reader {{
input_path: '{0}'
}}
""".format(tf_record_path, input_type)
input_reader_proto = input_reader_pb2.InputReader()
input_reader_proto.label_map_path = label_map_path
text_format.Merge(input_reader_text_proto, input_reader_proto)
reduce_to_frame_fn = self.get_mock_reduce_to_frame_fn()
def graph_fn():
return get_iterator_next_for_testing(
dataset_builder.build(input_reader_proto, batch_size=1,
reduce_to_frame_fn=reduce_to_frame_fn),
self.is_tf2())
output_dict = self.execute(graph_fn, [])
self.assertEqual((1,),
output_dict[fields.InputDataFields.source_id].shape)
def test_build_tf_record_input_reader_sequence_example_test(self):
tf_record_path = self.create_tf_record_sequence_example()
input_type = 'TF_SEQUENCE_EXAMPLE'
label_map_path = _get_labelmap_path()
input_reader_text_proto = """
shuffle: false
num_readers: 1
input_type: {1}
tf_record_input_reader {{
input_path: '{0}'
}}
""".format(tf_record_path, input_type)
input_reader_proto = input_reader_pb2.InputReader()
text_format.Merge(input_reader_text_proto, input_reader_proto)
input_reader_proto.label_map_path = label_map_path
reduce_to_frame_fn = self.get_mock_reduce_to_frame_fn()
def graph_fn():
return get_iterator_next_for_testing(
dataset_builder.build(input_reader_proto, batch_size=1,
reduce_to_frame_fn=reduce_to_frame_fn),
self.is_tf2())
output_dict = self.execute(graph_fn, [])
self.assertEqual((1,),
output_dict[fields.InputDataFields.source_id].shape)
def test_build_tf_record_input_reader_and_load_instance_masks(self): def test_build_tf_record_input_reader_and_load_instance_masks(self):
tf_record_path = self.create_tf_record() tf_record_path = self.create_tf_record()
......
...@@ -23,6 +23,7 @@ from __future__ import division ...@@ -23,6 +23,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from object_detection.data_decoders import tf_example_decoder from object_detection.data_decoders import tf_example_decoder
from object_detection.data_decoders import tf_sequence_example_decoder
from object_detection.protos import input_reader_pb2 from object_detection.protos import input_reader_pb2
...@@ -46,16 +47,24 @@ def build(input_reader_config): ...@@ -46,16 +47,24 @@ def build(input_reader_config):
label_map_proto_file = None label_map_proto_file = None
if input_reader_config.HasField('label_map_path'): if input_reader_config.HasField('label_map_path'):
label_map_proto_file = input_reader_config.label_map_path label_map_proto_file = input_reader_config.label_map_path
decoder = tf_example_decoder.TfExampleDecoder( input_type = input_reader_config.input_type
load_instance_masks=input_reader_config.load_instance_masks, if input_type == input_reader_pb2.InputType.TF_EXAMPLE:
load_multiclass_scores=input_reader_config.load_multiclass_scores, decoder = tf_example_decoder.TfExampleDecoder(
load_context_features=input_reader_config.load_context_features, load_instance_masks=input_reader_config.load_instance_masks,
instance_mask_type=input_reader_config.mask_type, load_multiclass_scores=input_reader_config.load_multiclass_scores,
label_map_proto_file=label_map_proto_file, load_context_features=input_reader_config.load_context_features,
use_display_name=input_reader_config.use_display_name, instance_mask_type=input_reader_config.mask_type,
num_additional_channels=input_reader_config.num_additional_channels, label_map_proto_file=label_map_proto_file,
num_keypoints=input_reader_config.num_keypoints) use_display_name=input_reader_config.use_display_name,
num_additional_channels=input_reader_config.num_additional_channels,
return decoder num_keypoints=input_reader_config.num_keypoints,
expand_hierarchy_labels=input_reader_config.expand_labels_hierarchy)
return decoder
elif input_type == input_reader_pb2.InputType.TF_SEQUENCE_EXAMPLE:
decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
label_map_proto_file=label_map_proto_file,
load_context_features=input_reader_config.load_context_features)
return decoder
raise ValueError('Unsupported input_type in config.')
raise ValueError('Unsupported input_reader_config.') raise ValueError('Unsupported input_reader_config.')
...@@ -19,16 +19,25 @@ from __future__ import absolute_import ...@@ -19,16 +19,25 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import numpy as np import numpy as np
import tensorflow as tf import tensorflow.compat.v1 as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.builders import decoder_builder from object_detection.builders import decoder_builder
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.dataset_tools import seq_example_util
from object_detection.protos import input_reader_pb2 from object_detection.protos import input_reader_pb2
from object_detection.utils import dataset_util from object_detection.utils import dataset_util
def _get_labelmap_path():
"""Returns an absolute path to label map file."""
parent_path = os.path.dirname(tf.resource_loader.get_data_files_path())
return os.path.join(parent_path, 'data',
'pet_label_map.pbtxt')
class DecoderBuilderTest(tf.test.TestCase): class DecoderBuilderTest(tf.test.TestCase):
def _make_serialized_tf_example(self, has_additional_channels=False): def _make_serialized_tf_example(self, has_additional_channels=False):
...@@ -60,6 +69,50 @@ class DecoderBuilderTest(tf.test.TestCase): ...@@ -60,6 +69,50 @@ class DecoderBuilderTest(tf.test.TestCase):
example = tf.train.Example(features=tf.train.Features(feature=features)) example = tf.train.Example(features=tf.train.Features(feature=features))
return example.SerializeToString() return example.SerializeToString()
def _make_random_serialized_jpeg_images(self, num_frames, image_height,
image_width):
images = tf.cast(tf.random.uniform(
[num_frames, image_height, image_width, 3],
maxval=256,
dtype=tf.int32), dtype=tf.uint8)
images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
with tf.Session() as sess:
encoded_images = sess.run(encoded_images_list)
return encoded_images
def _make_serialized_tf_sequence_example(self):
num_frames = 4
image_height = 20
image_width = 30
image_source_ids = [str(i) for i in range(num_frames)]
with self.test_session():
encoded_images = self._make_random_serialized_jpeg_images(
num_frames, image_height, image_width)
sequence_example_serialized = seq_example_util.make_sequence_example(
dataset_name='video_dataset',
video_id='video',
encoded_images=encoded_images,
image_height=image_height,
image_width=image_width,
image_source_ids=image_source_ids,
image_format='JPEG',
is_annotated=[[1], [1], [1], [1]],
bboxes=[
[[]], # Frame 0.
[[0., 0., 1., 1.]], # Frame 1.
[[0., 0., 1., 1.],
[0.1, 0.1, 0.2, 0.2]], # Frame 2.
[[]], # Frame 3.
],
label_strings=[
[], # Frame 0.
['Abyssinian'], # Frame 1.
['Abyssinian', 'american_bulldog'], # Frame 2.
[], # Frame 3
]).SerializeToString()
return sequence_example_serialized
def test_build_tf_record_input_reader(self): def test_build_tf_record_input_reader(self):
input_reader_text_proto = 'tf_record_input_reader {}' input_reader_text_proto = 'tf_record_input_reader {}'
input_reader_proto = input_reader_pb2.InputReader() input_reader_proto = input_reader_pb2.InputReader()
...@@ -82,6 +135,43 @@ class DecoderBuilderTest(tf.test.TestCase): ...@@ -82,6 +135,43 @@ class DecoderBuilderTest(tf.test.TestCase):
[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0],
output_dict[fields.InputDataFields.groundtruth_boxes][0]) output_dict[fields.InputDataFields.groundtruth_boxes][0])
def test_build_tf_record_input_reader_sequence_example(self):
label_map_path = _get_labelmap_path()
input_reader_text_proto = """
input_type: TF_SEQUENCE_EXAMPLE
tf_record_input_reader {}
"""
input_reader_proto = input_reader_pb2.InputReader()
input_reader_proto.label_map_path = label_map_path
text_format.Parse(input_reader_text_proto, input_reader_proto)
decoder = decoder_builder.build(input_reader_proto)
tensor_dict = decoder.decode(self._make_serialized_tf_sequence_example())
with tf.train.MonitoredSession() as sess:
output_dict = sess.run(tensor_dict)
expected_groundtruth_classes = [[-1, -1], [1, -1], [1, 2], [-1, -1]]
expected_groundtruth_boxes = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 1.0, 1.0], [0.1, 0.1, 0.2, 0.2]],
[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]
expected_num_groundtruth_boxes = [0, 1, 2, 0]
self.assertNotIn(
fields.InputDataFields.groundtruth_instance_masks, output_dict)
# Sequence example images are encoded.
self.assertEqual((4,), output_dict[fields.InputDataFields.image].shape)
self.assertAllEqual(expected_groundtruth_classes,
output_dict[fields.InputDataFields.groundtruth_classes])
self.assertEqual(
(4, 2, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
self.assertAllClose(expected_groundtruth_boxes,
output_dict[fields.InputDataFields.groundtruth_boxes])
self.assertAllClose(
expected_num_groundtruth_boxes,
output_dict[fields.InputDataFields.num_groundtruth_boxes])
def test_build_tf_record_input_reader_and_load_instance_masks(self): def test_build_tf_record_input_reader_and_load_instance_masks(self):
input_reader_text_proto = """ input_reader_text_proto = """
load_instance_masks: true load_instance_masks: true
......
...@@ -14,11 +14,10 @@ ...@@ -14,11 +14,10 @@
# ============================================================================== # ==============================================================================
"""Functions for quantized training and evaluation.""" """Functions for quantized training and evaluation."""
import tensorflow as tf import tensorflow.compat.v1 as tf
import tf_slim as slim
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
try: try:
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize from tensorflow.contrib import quantize as contrib_quantize
except ImportError: except ImportError:
# TF 2.0 doesn't ship with contrib. # TF 2.0 doesn't ship with contrib.
...@@ -49,7 +48,6 @@ def build(graph_rewriter_config, is_training): ...@@ -49,7 +48,6 @@ def build(graph_rewriter_config, is_training):
contrib_quantize.experimental_create_eval_graph( contrib_quantize.experimental_create_eval_graph(
input_graph=tf.get_default_graph() input_graph=tf.get_default_graph()
) )
slim.summarize_collection('quant_vars')
contrib_layers.summarize_collection('quant_vars')
return graph_rewrite_fn return graph_rewrite_fn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment