Commit 9fce9c64 authored by Zhichao Lu, committed by pkulzc

Merged commit includes the following changes:

199348852  by Zhichao Lu:

    Small typo fixes in VRD evaluation.

--
199315191  by Zhichao Lu:

    Change padding shapes when additional channels are available.

--
199309180  by Zhichao Lu:

    Adds minor fixes to the Object Detection API implementation.

--
199298605  by Zhichao Lu:

    Force num_readers to be 1 when the only input file is not sharded.

--
199292952  by Zhichao Lu:

    Adds image-level labels parsing into TfExampleDetectionAndGTParser.

--
199259866  by Zhichao Lu:

    Visual Relationships Evaluation executable.

--
199208330  by Zhichao Lu:

    Treat train_config.batch_size as the effective batch size. The trainer therefore divides the effective batch size by train_config.replicas_to_aggregate to get the per-worker batch size.

--
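
The batch-size arithmetic described in the entry above can be sketched as follows. This is a minimal illustration, not the trainer's actual code: the function and parameter names are hypothetical, and the even-divisibility check is an assumption added for the example.

```python
def per_worker_batch_size(effective_batch_size, num_replicas):
    """Returns the batch size each worker should use.

    Both parameter names are hypothetical stand-ins for the config fields
    mentioned in the commit message.
    """
    if num_replicas < 1:
        raise ValueError('num_replicas must be at least 1')
    # Assumption: the effective batch size splits evenly across replicas.
    if effective_batch_size % num_replicas != 0:
        raise ValueError(
            'effective_batch_size must be divisible by num_replicas')
    return effective_batch_size // num_replicas


# An effective batch of 64 aggregated over 8 replicas gives 8 per worker.
print(per_worker_batch_size(64, 8))
```

With a single replica the per-worker batch size equals the effective batch size, so existing single-worker configs keep their meaning.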
199207842  by Zhichao Lu:

    Internal change.

--
199204222  by Zhichao Lu:

    In case the image has more than three channels, we only take the first three channels for visualization.

--
199194388  by Zhichao Lu:

    Correcting protocols description: VOC 2007 -> VOC 2012.

--
199188290  by Zhichao Lu:

    Adds per-relationship APs and mAP computation to VRD evaluation.

--
199158801  by Zhichao Lu:

    If available, additional channels are merged with input image.

--
199099637  by Zhichao Lu:

    OpenImages Challenge metric support:
    -adding verified labels standard field for TFExample;
    -adding tfrecord creation functionality.

--
198957391  by Zhichao Lu:

    Allow tf record sharding when creating pets dataset.

--
198925184  by Zhichao Lu:

    Introduce moving average support for evaluation. Also adding the ability to override this configuration via config_util.

--
198918186  by Zhichao Lu:

    Handles the case where there are 0 box masks.

--
198809009  by Zhichao Lu:

    Plumb groundtruth weights into target assigner for Faster RCNN.

--
198759987  by Zhichao Lu:

    Fix object detection test broken by shape inference.

--
198668602  by Zhichao Lu:

    Adding a new input field in data_decoders/tf_example_decoder.py for storing additional channels.

--
198530013  by Zhichao Lu:

    A util for hierarchical expansion of boxes and labels of the OID dataset.

--
198503124  by Zhichao Lu:

    Fix dimension mismatch error introduced by
    https://github.com/tensorflow/tensorflow/pull/18251, or cl/194031845.
    After above change, conv2d strictly checks for conv_dims + 2 == input_rank.

--
198445807  by Zhichao Lu:

    Enabling Object Detection Challenge 2018 metric in evaluator.py framework for
    running eval job.
    Renaming old OpenImages V2 metric.

--
198413950  by Zhichao Lu:

    Support generic configuration override using namespaced keys

    Useful for adding custom hyper-parameter tuning fields without having to add custom override methods to config_utils.py.

--
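
A namespaced-key override of the kind described in the entry above might look like the sketch below. It uses plain nested dicts in place of config protos, and `apply_override` is a hypothetical helper, not the function in config_util.py.

```python
def apply_override(config, namespaced_key, value):
    """Sets a nested field addressed by a dotted key such as 'a.b.c'.

    `config` is a plain nested dict standing in for a config proto
    (an assumption for this sketch). The path must already exist.
    """
    *parents, leaf = namespaced_key.split('.')
    node = config
    for name in parents:
        child = node.get(name)
        if not isinstance(child, dict):
            raise KeyError('No such config section: {}'.format(name))
        node = child
    if leaf not in node:
        raise KeyError('No such config field: {}'.format(namespaced_key))
    node[leaf] = value
    return config


config = {
    'train_config': {
        'batch_size': 8,
        'optimizer': {'learning_rate': 0.1},
    },
}
apply_override(config, 'train_config.optimizer.learning_rate', 0.01)
```

Rejecting unknown keys, rather than creating them, is a design choice made here so that a misspelled hyper-parameter name fails loudly.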
198106437  by Zhichao Lu:

    Enable fused batchnorm now that quantization is supported.

--
198048364  by Zhichao Lu:

    Add support for keypoints in tf sequence examples and some util ops.

--
198004736  by Zhichao Lu:

    Relax postprocessing unit tests that are based on the assumption that tf.image.non_max_suppression is stable with respect to input.

--
197997513  by Zhichao Lu:

    More lenient validation for normalized box boundaries.

--
197940068  by Zhichao Lu:

    A couple of minor updates/fixes:
    - Updating input reader proto with option to use display_name when decoding data.
    - Updating visualization tool to specify whether using absolute or normalized box coordinates. Appropriate boxes will now appear in TB when using model_main.py

--
197920152  by Zhichao Lu:

    Add quantized training support in the new OD binaries and a config for SSD Mobilenet v1 quantized training that is TPU compatible.

--
197213563  by Zhichao Lu:

    Do not share batch_norm for classification and regression tower in weight shared box predictor.

--
197196757  by Zhichao Lu:

    Relax the box_predictor api to return box_prediction of shape [batch_size, num_anchors, code_size] in addition to [batch_size, num_anchors, (1|q), code_size].

--
196898361  by Zhichao Lu:

    Allow per-channel scalar value to pad input image with when using keep aspect ratio resizer (when pad_to_max_dimension=True).

    In Object Detection Pipeline, we pad image before normalization and this skews batch_norm statistics during training. The option to set per channel pad value lets us truly pad with zeros.

--
196592101  by Zhichao Lu:

    Fix bug regarding tfrecord shuffling in object_detection

--
196320138  by Zhichao Lu:

    Fix typo in exporting_models.md

--

PiperOrigin-RevId: 199348852
parent ed901b73
@@ -56,15 +56,26 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
   else:
     height, width = spatial_image_shape  # pylint: disable=unpacking-non-sequence
+  num_additional_channels = 0
+  if fields.InputDataFields.image_additional_channels in dataset.output_shapes:
+    num_additional_channels = dataset.output_shapes[
+        fields.InputDataFields.image_additional_channels].dims[2].value
   padding_shapes = {
-      fields.InputDataFields.image: [height, width, 3],
+      # Additional channels are merged before batching.
+      fields.InputDataFields.image: [
+          height, width, 3 + num_additional_channels
+      ],
+      fields.InputDataFields.image_additional_channels: [
+          height, width, num_additional_channels
+      ],
       fields.InputDataFields.source_id: [],
       fields.InputDataFields.filename: [],
       fields.InputDataFields.key: [],
       fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
       fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
-      fields.InputDataFields.groundtruth_instance_masks: [max_num_boxes, height,
-                                                          width],
+      fields.InputDataFields.groundtruth_instance_masks: [
+          max_num_boxes, height, width
+      ],
       fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
       fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
       fields.InputDataFields.groundtruth_area: [max_num_boxes],
@@ -74,7 +85,8 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
       fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
       fields.InputDataFields.true_image_shape: [3],
       fields.InputDataFields.multiclass_scores: [
-          max_num_boxes, num_classes + 1 if num_classes is not None else None],
+          max_num_boxes, num_classes + 1 if num_classes is not None else None
+      ],
   }
   # Determine whether groundtruth_classes are integers or one-hot encodings, and
   # apply batching appropriately.
@@ -90,7 +102,9 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
                      'rank 2 tensor (one-hot encodings)')
   if fields.InputDataFields.original_image in dataset.output_shapes:
-    padding_shapes[fields.InputDataFields.original_image] = [None, None, 3]
+    padding_shapes[fields.InputDataFields.original_image] = [
+        None, None, 3 + num_additional_channels
+    ]
   if fields.InputDataFields.groundtruth_keypoints in dataset.output_shapes:
     tensor_shape = dataset.output_shapes[fields.InputDataFields.
                                          groundtruth_keypoints]
@@ -108,9 +122,13 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
           for tensor_key, _ in dataset.output_shapes.items()}
-def build(input_reader_config, transform_input_data_fn=None,
-          batch_size=None, max_num_boxes=None, num_classes=None,
-          spatial_image_shape=None):
+def build(input_reader_config,
+          transform_input_data_fn=None,
+          batch_size=None,
+          max_num_boxes=None,
+          num_classes=None,
+          spatial_image_shape=None,
+          num_additional_channels=0):
   """Builds a tf.data.Dataset.
   Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
@@ -128,6 +146,7 @@ def build(input_reader_config, transform_input_data_fn=None,
     spatial_image_shape: A list of two integers of the form [height, width]
       containing expected spatial shape of the image after applying
       transform_input_data_fn. If None, will use dynamic shapes.
+    num_additional_channels: Number of additional channels to use in the input.
   Returns:
     A tf.data.Dataset based on the input_reader_config.
@@ -152,7 +171,9 @@ def build(input_reader_config, transform_input_data_fn=None,
     decoder = tf_example_decoder.TfExampleDecoder(
         load_instance_masks=input_reader_config.load_instance_masks,
         instance_mask_type=input_reader_config.mask_type,
-        label_map_proto_file=label_map_proto_file)
+        label_map_proto_file=label_map_proto_file,
+        use_display_name=input_reader_config.use_display_name,
+        num_additional_channels=num_additional_channels)
     def process_fn(value):
       processed = decoder.decode(value)
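
The effect of the padding-shape change above can be sketched without TensorFlow. This is a simplified illustration under stated assumptions: string keys stand in for the `fields.InputDataFields` constants, and `image_padding_shapes` is a hypothetical helper, not part of dataset_builder.

```python
def image_padding_shapes(height, width, num_additional_channels=0):
    """Returns padding shapes for the image-related fields.

    Additional channels are assumed to be merged into the image before
    batching, so the image depth is 3 plus the number of extra channels.
    """
    return {
        'image': [height, width, 3 + num_additional_channels],
        'image_additional_channels': [
            height, width, num_additional_channels
        ],
    }


shapes = image_padding_shapes(4, 5, num_additional_channels=2)
print(shapes['image'])
```

For a 4x5 image with two additional channels the image depth becomes 5, which agrees with the `(2, 4, 5, 5)` batched shape asserted in the new dataset builder test.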
@@ -30,49 +30,50 @@ from object_detection.utils import dataset_util
 class DatasetBuilderTest(tf.test.TestCase):
-  def create_tf_record(self):
+  def create_tf_record(self, has_additional_channels=False):
     path = os.path.join(self.get_temp_dir(), 'tfrecord')
     writer = tf.python_io.TFRecordWriter(path)
     image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
+    additional_channels_tensor = np.random.randint(
+        255, size=(4, 5, 1)).astype(np.uint8)
     flat_mask = (4 * 5) * [1.0]
     with self.test_session():
       encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
+      encoded_additional_channels_jpeg = tf.image.encode_jpeg(
+          tf.constant(additional_channels_tensor)).eval()
+    features = {
+        'image/encoded':
+            feature_pb2.Feature(
+                bytes_list=feature_pb2.BytesList(value=[encoded_jpeg])),
+        'image/format':
+            feature_pb2.Feature(
+                bytes_list=feature_pb2.BytesList(value=['jpeg'.encode('utf-8')])
+            ),
+        'image/height':
+            feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=[4])),
+        'image/width':
+            feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=[5])),
+        'image/object/bbox/xmin':
+            feature_pb2.Feature(float_list=feature_pb2.FloatList(value=[0.0])),
+        'image/object/bbox/xmax':
+            feature_pb2.Feature(float_list=feature_pb2.FloatList(value=[1.0])),
+        'image/object/bbox/ymin':
+            feature_pb2.Feature(float_list=feature_pb2.FloatList(value=[0.0])),
+        'image/object/bbox/ymax':
+            feature_pb2.Feature(float_list=feature_pb2.FloatList(value=[1.0])),
+        'image/object/class/label':
+            feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=[2])),
+        'image/object/mask':
+            feature_pb2.Feature(
+                float_list=feature_pb2.FloatList(value=flat_mask)),
+    }
+    if has_additional_channels:
+      features['image/additional_channels/encoded'] = feature_pb2.Feature(
+          bytes_list=feature_pb2.BytesList(
+              value=[encoded_additional_channels_jpeg] * 2))
     example = example_pb2.Example(
-        features=feature_pb2.Features(
-            feature={
-                'image/encoded':
-                    feature_pb2.Feature(
-                        bytes_list=feature_pb2.BytesList(value=[encoded_jpeg])),
-                'image/format':
-                    feature_pb2.Feature(
-                        bytes_list=feature_pb2.BytesList(
-                            value=['jpeg'.encode('utf-8')])),
-                'image/height':
-                    feature_pb2.Feature(
-                        int64_list=feature_pb2.Int64List(value=[4])),
-                'image/width':
-                    feature_pb2.Feature(
-                        int64_list=feature_pb2.Int64List(value=[5])),
-                'image/object/bbox/xmin':
-                    feature_pb2.Feature(
-                        float_list=feature_pb2.FloatList(value=[0.0])),
-                'image/object/bbox/xmax':
-                    feature_pb2.Feature(
-                        float_list=feature_pb2.FloatList(value=[1.0])),
-                'image/object/bbox/ymin':
-                    feature_pb2.Feature(
-                        float_list=feature_pb2.FloatList(value=[0.0])),
-                'image/object/bbox/ymax':
-                    feature_pb2.Feature(
-                        float_list=feature_pb2.FloatList(value=[1.0])),
-                'image/object/class/label':
-                    feature_pb2.Feature(
-                        int64_list=feature_pb2.Int64List(value=[2])),
-                'image/object/mask':
-                    feature_pb2.Feature(
-                        float_list=feature_pb2.FloatList(value=flat_mask)),
-            }))
+        features=feature_pb2.Features(feature=features))
     writer.write(example.SerializeToString())
     writer.close()
@@ -218,6 +219,31 @@ class DatasetBuilderTest(tf.test.TestCase):
         [2, 2, 4, 5],
         output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
+  def test_build_tf_record_input_reader_with_additional_channels(self):
+    tf_record_path = self.create_tf_record(has_additional_channels=True)
+    input_reader_text_proto = """
+      shuffle: false
+      num_readers: 1
+      tf_record_input_reader {{
+        input_path: '{0}'
+      }}
+    """.format(tf_record_path)
+    input_reader_proto = input_reader_pb2.InputReader()
+    text_format.Merge(input_reader_text_proto, input_reader_proto)
+    tensor_dict = dataset_util.make_initializable_iterator(
+        dataset_builder.build(
+            input_reader_proto, batch_size=2,
+            num_additional_channels=2)).get_next()
+    sv = tf.train.Supervisor(logdir=self.get_temp_dir())
+    with sv.prepare_or_wait_for_session() as sess:
+      sv.start_queue_runners(sess)
+      output_dict = sess.run(tensor_dict)
+    self.assertEquals((2, 4, 5, 5),
+                      output_dict[fields.InputDataFields.image].shape)
   def test_raises_error_with_no_input_paths(self):
     input_reader_text_proto = """
       shuffle: false
@@ -79,12 +79,17 @@ def build(image_resizer_config):
             keep_aspect_ratio_config.max_dimension):
       raise ValueError('min_dimension > max_dimension')
     method = _tf_resize_method(keep_aspect_ratio_config.resize_method)
+    per_channel_pad_value = (0, 0, 0)
+    if keep_aspect_ratio_config.per_channel_pad_value:
+      per_channel_pad_value = tuple(keep_aspect_ratio_config.
+                                    per_channel_pad_value)
     image_resizer_fn = functools.partial(
         preprocessor.resize_to_range,
         min_dimension=keep_aspect_ratio_config.min_dimension,
         max_dimension=keep_aspect_ratio_config.max_dimension,
         method=method,
-        pad_to_max_dimension=keep_aspect_ratio_config.pad_to_max_dimension)
+        pad_to_max_dimension=keep_aspect_ratio_config.pad_to_max_dimension,
+        per_channel_pad_value=per_channel_pad_value)
     if not keep_aspect_ratio_config.convert_to_grayscale:
       return image_resizer_fn
   elif image_resizer_oneof == 'fixed_shape_resizer':
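
The commit message motivates `per_channel_pad_value` by noting that padding happens before normalization. The sketch below illustrates that reasoning with nested Python lists. It assumes normalization is a per-channel mean subtraction, which the source does not state, and both helper names are hypothetical.

```python
def pad_to_square(image, size, per_channel_pad_value):
    """Pads an image (rows of pixels of channel values) to size x size."""
    pad_pixel = list(per_channel_pad_value)
    width = len(image[0]) if image else 0
    # Extend each existing row on the right, then append full padding rows.
    padded = [
        row + [list(pad_pixel) for _ in range(size - width)] for row in image
    ]
    padded += [
        [list(pad_pixel) for _ in range(size)]
        for _ in range(size - len(image))
    ]
    return padded


def subtract_mean(image, channel_means):
    """Assumed normalization step: subtracts a per-channel mean."""
    return [
        [[v - m for v, m in zip(pixel, channel_means)] for pixel in row]
        for row in image
    ]


channel_means = (1, 2, 3)
image = [[[10, 20, 30]]]  # a single 1x1 RGB pixel
padded = pad_to_square(image, 2, per_channel_pad_value=channel_means)
normalized = subtract_mean(padded, channel_means)
print(normalized[1][1])
```

Padding with the channel means leaves the padded region at exactly zero after normalization, whereas padding with raw zeros would leave it at the negated means.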
@@ -52,6 +52,9 @@ class ImageResizerBuilderTest(tf.test.TestCase):
         min_dimension: 10
         max_dimension: 20
         pad_to_max_dimension: true
+        per_channel_pad_value: 3
+        per_channel_pad_value: 4
+        per_channel_pad_value: 5
       }
     """
     input_shape = (50, 25, 3)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A function to build a DetectionModel from configuration."""
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
from object_detection.core import box_predictor
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
from object_detection.protos import model_pb2
# A map of names to SSD feature extractors.
SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
'ssd_inception_v2': SSDInceptionV2FeatureExtractor,
'ssd_inception_v3': SSDInceptionV3FeatureExtractor,
'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor,
'ssd_mobilenet_v2': SSDMobileNetV2FeatureExtractor,
'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
}
# A map of names to Faster R-CNN feature extractors.
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
'faster_rcnn_nas':
frcnn_nas.FasterRCNNNASFeatureExtractor,
'faster_rcnn_pnas':
frcnn_pnas.FasterRCNNPNASFeatureExtractor,
'faster_rcnn_inception_resnet_v2':
frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
'faster_rcnn_inception_v2':
frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor,
'faster_rcnn_resnet50':
frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
'faster_rcnn_resnet101':
frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
'faster_rcnn_resnet152':
frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
}
def build(model_config, is_training, add_summaries=True,
add_background_class=True):
"""Builds a DetectionModel based on the model config.
Args:
model_config: A model.proto object containing the config for the desired
DetectionModel.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tensorflow summaries in the model graph.
add_background_class: Whether to add an implicit background class to one-hot
encodings of groundtruth labels. Set to false if using groundtruth labels
with an explicit background class or using multiclass scores instead of
truth in the case of distillation. Ignored in the case of faster_rcnn.
Returns:
DetectionModel based on the config.
Raises:
ValueError: On invalid meta architecture or model.
"""
if not isinstance(model_config, model_pb2.DetectionModel):
raise ValueError('model_config not of type model_pb2.DetectionModel.')
meta_architecture = model_config.WhichOneof('model')
if meta_architecture == 'ssd':
return _build_ssd_model(model_config.ssd, is_training, add_summaries,
add_background_class)
if meta_architecture == 'faster_rcnn':
return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
add_summaries)
raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
def _build_ssd_feature_extractor(feature_extractor_config, is_training,
reuse_weights=None):
"""Builds a ssd_meta_arch.SSDFeatureExtractor based on config.
Args:
feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
is_training: True if this feature extractor is being built for training.
reuse_weights: Whether the feature extractor should reuse weights.
Returns:
ssd_meta_arch.SSDFeatureExtractor based on config.
Raises:
ValueError: On invalid feature extractor type.
"""
feature_type = feature_extractor_config.type
depth_multiplier = feature_extractor_config.depth_multiplier
min_depth = feature_extractor_config.min_depth
pad_to_multiple = feature_extractor_config.pad_to_multiple
use_explicit_padding = feature_extractor_config.use_explicit_padding
use_depthwise = feature_extractor_config.use_depthwise
conv_hyperparams = hyperparams_builder.build(
feature_extractor_config.conv_hyperparams, is_training)
override_base_feature_extractor_hyperparams = (
feature_extractor_config.override_base_feature_extractor_hyperparams)
if feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))
feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
return feature_extractor_class(
is_training, depth_multiplier, min_depth, pad_to_multiple,
conv_hyperparams, reuse_weights, use_explicit_padding, use_depthwise,
override_base_feature_extractor_hyperparams)
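
The lookup pattern used by `_build_ssd_feature_extractor` above, a name-to-class map with an explicit error for unknown names, can be sketched in isolation. The toy classes and the `TOY_CLASS_MAP` name are hypothetical and stand in for the real feature extractors.

```python
class ToyExtractorA(object):
    def __init__(self, is_training, depth_multiplier):
        self.is_training = is_training
        self.depth_multiplier = depth_multiplier


class ToyExtractorB(ToyExtractorA):
    pass


# A map of names to (toy) feature extractor classes.
TOY_CLASS_MAP = {
    'toy_a': ToyExtractorA,
    'toy_b': ToyExtractorB,
}


def build_toy_extractor(feature_type, is_training, depth_multiplier=1.0):
    """Looks up the class registered under feature_type and instantiates it."""
    if feature_type not in TOY_CLASS_MAP:
        raise ValueError(
            'Unknown feature_extractor: {}'.format(feature_type))
    extractor_class = TOY_CLASS_MAP[feature_type]
    return extractor_class(is_training, depth_multiplier)


extractor = build_toy_extractor('toy_b', is_training=True)
print(type(extractor).__name__)
```

Supporting a new extractor then only requires a new map entry, and a misspelled type in a config fails with a `ValueError` rather than a bare `KeyError`.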
def _build_ssd_model(ssd_config, is_training, add_summaries,
add_background_class=True):
"""Builds an SSD detection model based on the model config.
Args:
ssd_config: A ssd.proto object containing the config for the desired
SSDMetaArch.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model.
add_background_class: Whether to add an implicit background class to one-hot
encodings of groundtruth labels. Set to false if using groundtruth labels
with an explicit background class or using multiclass scores instead of
truth in the case of distillation.
Returns:
SSDMetaArch based on the config.
Raises:
ValueError: If ssd_config.feature_extractor.type is not recognized (i.e. not
registered in SSD_FEATURE_EXTRACTOR_CLASS_MAP).
"""
num_classes = ssd_config.num_classes
# Feature extractor
feature_extractor = _build_ssd_feature_extractor(
feature_extractor_config=ssd_config.feature_extractor,
is_training=is_training)
box_coder = box_coder_builder.build(ssd_config.box_coder)
matcher = matcher_builder.build(ssd_config.matcher)
region_similarity_calculator = sim_calc.build(
ssd_config.similarity_calculator)
encode_background_as_zeros = ssd_config.encode_background_as_zeros
negative_class_weight = ssd_config.negative_class_weight
ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build,
ssd_config.box_predictor,
is_training, num_classes)
anchor_generator = anchor_generator_builder.build(
ssd_config.anchor_generator)
image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
ssd_config.post_processing)
(classification_loss, localization_loss, classification_weight,
localization_weight, hard_example_miner,
random_example_sampler) = losses_builder.build(ssd_config.loss)
normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize
return ssd_meta_arch.SSDMetaArch(
is_training,
anchor_generator,
ssd_box_predictor,
box_coder,
feature_extractor,
matcher,
region_similarity_calculator,
encode_background_as_zeros,
negative_class_weight,
image_resizer_fn,
non_max_suppression_fn,
score_conversion_fn,
classification_loss,
localization_loss,
classification_weight,
localization_weight,
normalize_loss_by_num_matches,
hard_example_miner,
add_summaries=add_summaries,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=ssd_config.freeze_batchnorm,
inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
add_background_class=add_background_class,
random_example_sampler=random_example_sampler)
def _build_faster_rcnn_feature_extractor(
feature_extractor_config, is_training, reuse_weights=None,
inplace_batchnorm_update=False):
"""Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.
Args:
feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
faster_rcnn.proto.
is_training: True if this feature extractor is being built for training.
reuse_weights: Whether the feature extractor should reuse weights.
inplace_batchnorm_update: Whether to update batch_norm inplace during
training. This is required for batch norm to work correctly on TPUs. When
this is false, user must add a control dependency on
tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
norm moving average parameters.
Returns:
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.
Raises:
ValueError: On invalid feature extractor type.
"""
if inplace_batchnorm_update:
raise ValueError('inplace batchnorm updates not supported.')
feature_type = feature_extractor_config.type
first_stage_features_stride = (
feature_extractor_config.first_stage_features_stride)
batch_norm_trainable = feature_extractor_config.batch_norm_trainable
if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
feature_type))
feature_extractor_class = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[
feature_type]
return feature_extractor_class(
is_training, first_stage_features_stride,
batch_norm_trainable, reuse_weights)
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
"""Builds a Faster R-CNN or R-FCN detection model based on the model config.
Builds R-FCN model if the second_stage_box_predictor in the config is of type
`rfcn_box_predictor` else builds a Faster R-CNN model.
Args:
frcnn_config: A faster_rcnn.proto object containing the config for the
desired FasterRCNNMetaArch or RFCNMetaArch.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model.
Returns:
FasterRCNNMetaArch or RFCNMetaArch based on the config.
Raises:
ValueError: If frcnn_config.feature_extractor.type is not recognized (i.e.
not registered in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP).
"""
num_classes = frcnn_config.num_classes
image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)
feature_extractor = _build_faster_rcnn_feature_extractor(
frcnn_config.feature_extractor, is_training,
inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)
number_of_stages = frcnn_config.number_of_stages
first_stage_anchor_generator = anchor_generator_builder.build(
frcnn_config.first_stage_anchor_generator)
first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
first_stage_box_predictor_kernel_size = (
frcnn_config.first_stage_box_predictor_kernel_size)
first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
first_stage_positive_balance_fraction = (
frcnn_config.first_stage_positive_balance_fraction)
first_stage_nms_score_threshold = frcnn_config.first_stage_nms_score_threshold
first_stage_nms_iou_threshold = frcnn_config.first_stage_nms_iou_threshold
first_stage_max_proposals = frcnn_config.first_stage_max_proposals
first_stage_loc_loss_weight = (
frcnn_config.first_stage_localization_loss_weight)
first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight
initial_crop_size = frcnn_config.initial_crop_size
maxpool_kernel_size = frcnn_config.maxpool_kernel_size
maxpool_stride = frcnn_config.maxpool_stride
second_stage_box_predictor = box_predictor_builder.build(
hyperparams_builder.build,
frcnn_config.second_stage_box_predictor,
is_training=is_training,
num_classes=num_classes)
second_stage_batch_size = frcnn_config.second_stage_batch_size
second_stage_balance_fraction = frcnn_config.second_stage_balance_fraction
(second_stage_non_max_suppression_fn, second_stage_score_conversion_fn
) = post_processing_builder.build(frcnn_config.second_stage_post_processing)
second_stage_localization_loss_weight = (
frcnn_config.second_stage_localization_loss_weight)
second_stage_classification_loss = (
losses_builder.build_faster_rcnn_classification_loss(
frcnn_config.second_stage_classification_loss))
second_stage_classification_loss_weight = (
frcnn_config.second_stage_classification_loss_weight)
second_stage_mask_prediction_loss_weight = (
frcnn_config.second_stage_mask_prediction_loss_weight)
hard_example_miner = None
if frcnn_config.HasField('hard_example_miner'):
hard_example_miner = losses_builder.build_hard_example_miner(
frcnn_config.hard_example_miner,
second_stage_classification_loss_weight,
second_stage_localization_loss_weight)
common_kwargs = {
'is_training': is_training,
'num_classes': num_classes,
'image_resizer_fn': image_resizer_fn,
'feature_extractor': feature_extractor,
'number_of_stages': number_of_stages,
'first_stage_anchor_generator': first_stage_anchor_generator,
'first_stage_atrous_rate': first_stage_atrous_rate,
'first_stage_box_predictor_arg_scope_fn':
first_stage_box_predictor_arg_scope_fn,
'first_stage_box_predictor_kernel_size':
first_stage_box_predictor_kernel_size,
'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
'first_stage_minibatch_size': first_stage_minibatch_size,
'first_stage_positive_balance_fraction':
first_stage_positive_balance_fraction,
'first_stage_nms_score_threshold': first_stage_nms_score_threshold,
'first_stage_nms_iou_threshold': first_stage_nms_iou_threshold,
'first_stage_max_proposals': first_stage_max_proposals,
'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
'second_stage_batch_size': second_stage_batch_size,
'second_stage_balance_fraction': second_stage_balance_fraction,
'second_stage_non_max_suppression_fn':
second_stage_non_max_suppression_fn,
'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
'second_stage_localization_loss_weight':
second_stage_localization_loss_weight,
'second_stage_classification_loss':
second_stage_classification_loss,
'second_stage_classification_loss_weight':
second_stage_classification_loss_weight,
'hard_example_miner': hard_example_miner,
'add_summaries': add_summaries}
if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
return rfcn_meta_arch.RFCNMetaArch(
second_stage_rfcn_box_predictor=second_stage_box_predictor,
**common_kwargs)
else:
return faster_rcnn_meta_arch.FasterRCNNMetaArch(
initial_crop_size=initial_crop_size,
maxpool_kernel_size=maxpool_kernel_size,
maxpool_stride=maxpool_stride,
second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
second_stage_mask_prediction_loss_weight=(
second_stage_mask_prediction_loss_weight),
**common_kwargs)
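
The tail of `_build_faster_rcnn_model` above collects the arguments shared by both meta-architectures in one dict and picks the class from the type of the box predictor. A stripped-down sketch of that structure follows. All class names here are hypothetical toys, and the crop size of 14 is an arbitrary illustrative value.

```python
class ToyBoxPredictor(object):
    pass


class ToyRfcnBoxPredictor(ToyBoxPredictor):
    pass


class ToyFasterRCNN(object):
    def __init__(self, is_training, num_classes, initial_crop_size,
                 box_predictor):
        self.is_training = is_training
        self.num_classes = num_classes
        self.initial_crop_size = initial_crop_size
        self.box_predictor = box_predictor


class ToyRFCN(object):
    def __init__(self, is_training, num_classes, rfcn_box_predictor):
        self.is_training = is_training
        self.num_classes = num_classes
        self.rfcn_box_predictor = rfcn_box_predictor


def build_toy_model(box_predictor, is_training, num_classes):
    # Arguments accepted by both constructors are gathered once.
    common_kwargs = {'is_training': is_training, 'num_classes': num_classes}
    if isinstance(box_predictor, ToyRfcnBoxPredictor):
        return ToyRFCN(rfcn_box_predictor=box_predictor, **common_kwargs)
    return ToyFasterRCNN(
        initial_crop_size=14, box_predictor=box_predictor, **common_kwargs)


model = build_toy_model(ToyRfcnBoxPredictor(), is_training=True, num_classes=3)
print(type(model).__name__)
```

Passing everything by keyword keeps each value bound to the intended parameter even when the two constructors order their arguments differently.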
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.builders.model_builder."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import model_builder
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
from object_detection.protos import model_pb2
FRCNN_RESNET_FEAT_MAPS = {
'faster_rcnn_resnet50':
frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
'faster_rcnn_resnet101':
frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
'faster_rcnn_resnet152':
frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor
}
SSD_RESNET_V1_FPN_FEAT_MAPS = {
'ssd_resnet50_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
'ssd_resnet101_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
'ssd_resnet152_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor
}
class ModelBuilderTest(tf.test.TestCase):
def create_model(self, model_config):
"""Builds a DetectionModel based on the model config.
Args:
model_config: A model.proto object containing the config for the desired
DetectionModel.
Returns:
DetectionModel based on the config.
"""
return model_builder.build(model_config, is_training=True)
def test_create_ssd_inception_v2_model_from_config(self):
model_text_proto = """
ssd {
feature_extractor {
type: 'ssd_inception_v2'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
override_base_feature_extractor_hyperparams: true
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
aspect_ratios: 1.0
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
convolutional_box_predictor {
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
loss {
classification_loss {
weighted_softmax {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = self.create_model(model_proto)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor,
SSDInceptionV2FeatureExtractor)
def test_create_ssd_inception_v3_model_from_config(self):
model_text_proto = """
ssd {
feature_extractor {
type: 'ssd_inception_v3'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
override_base_feature_extractor_hyperparams: true
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
aspect_ratios: 1.0
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
convolutional_box_predictor {
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
loss {
classification_loss {
weighted_softmax {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = self.create_model(model_proto)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor,
SSDInceptionV3FeatureExtractor)
def test_create_ssd_resnet_v1_fpn_model_from_config(self):
model_text_proto = """
ssd {
feature_extractor {
type: 'ssd_resnet50_v1_fpn'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
encode_background_as_zeros: true
anchor_generator {
multiscale_anchor_generator {
aspect_ratios: [1.0, 2.0, 0.5]
scales_per_octave: 2
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
weight_shared_convolutional_box_predictor {
depth: 32
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
random_normal_initializer {
}
}
}
num_layers_before_predictor: 1
}
}
normalize_loss_by_num_matches: true
normalize_loc_loss_by_codesize: true
loss {
classification_loss {
weighted_sigmoid_focal {
alpha: 0.25
gamma: 2.0
}
}
localization_loss {
weighted_smooth_l1 {
delta: 0.1
}
}
classification_weight: 1.0
localization_weight: 1.0
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
for extractor_type, extractor_class in SSD_RESNET_V1_FPN_FEAT_MAPS.items():
model_proto.ssd.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor, extractor_class)
def test_create_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """
ssd {
freeze_batchnorm: true
inplace_batchnorm_update: true
feature_extractor {
type: 'ssd_mobilenet_v1'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
aspect_ratios: 1.0
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
convolutional_box_predictor {
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
normalize_loc_loss_by_codesize: true
loss {
classification_loss {
weighted_softmax {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = self.create_model(model_proto)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor,
SSDMobileNetV1FeatureExtractor)
self.assertTrue(model._normalize_loc_loss_by_codesize)
self.assertTrue(model._freeze_batchnorm)
self.assertTrue(model._inplace_batchnorm_update)
def test_create_ssd_mobilenet_v2_model_from_config(self):
model_text_proto = """
ssd {
feature_extractor {
type: 'ssd_mobilenet_v2'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
aspect_ratios: 1.0
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
convolutional_box_predictor {
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
normalize_loc_loss_by_codesize: true
loss {
classification_loss {
weighted_softmax {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = self.create_model(model_proto)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor,
SSDMobileNetV2FeatureExtractor)
self.assertTrue(model._normalize_loc_loss_by_codesize)
def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """
ssd {
feature_extractor {
type: 'embedded_ssd_mobilenet_v1'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
aspect_ratios: 1.0
}
}
image_resizer {
fixed_shape_resizer {
height: 256
width: 256
}
}
box_predictor {
convolutional_box_predictor {
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
loss {
classification_loss {
weighted_softmax {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = self.create_model(model_proto)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor,
EmbeddedSSDMobileNetV1FeatureExtractor)
def test_create_faster_rcnn_resnet_v1_models_from_config(self):
model_text_proto = """
faster_rcnn {
inplace_batchnorm_update: true
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_resnet101'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
for extractor_type, extractor_class in FRCNN_RESNET_FEAT_MAPS.items():
model_proto.faster_rcnn.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
self.assertIsInstance(model._feature_extractor, extractor_class)
def test_create_faster_rcnn_resnet101_with_mask_prediction_enabled(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_resnet101'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
predict_instance_masks: true
}
}
second_stage_mask_prediction_loss_weight: 3.0
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = model_builder.build(model_proto, is_training=True)
self.assertAlmostEqual(model._second_stage_mask_loss_weight, 3.0)
def test_create_faster_rcnn_nas_model_from_config(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_nas'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 17
maxpool_kernel_size: 1
maxpool_stride: 1
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
self.assertIsInstance(
model._feature_extractor,
frcnn_nas.FasterRCNNNASFeatureExtractor)
def test_create_faster_rcnn_pnas_model_from_config(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_pnas'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 17
maxpool_kernel_size: 1
maxpool_stride: 1
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
self.assertIsInstance(
model._feature_extractor,
frcnn_pnas.FasterRCNNPNASFeatureExtractor)
def test_create_faster_rcnn_inception_resnet_v2_model_from_config(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_inception_resnet_v2'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 17
maxpool_kernel_size: 1
maxpool_stride: 1
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
self.assertIsInstance(
model._feature_extractor,
frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor)
def test_create_faster_rcnn_inception_v2_model_from_config(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_inception_v2'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
self.assertIsInstance(model._feature_extractor,
frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor)
def test_create_faster_rcnn_model_from_config_with_example_miner(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
feature_extractor {
type: 'faster_rcnn_inception_resnet_v2'
}
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
hard_example_miner {
num_hard_examples: 10
iou_threshold: 0.99
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = model_builder.build(model_proto, is_training=True)
self.assertIsNotNone(model._hard_example_miner)
def test_create_rfcn_resnet_v1_model_from_config(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_resnet101'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 14
maxpool_kernel_size: 2
maxpool_stride: 2
second_stage_box_predictor {
rfcn_box_predictor {
conv_hyperparams {
op: CONV
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
for extractor_type, extractor_class in FRCNN_RESNET_FEAT_MAPS.items():
model_proto.faster_rcnn.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch)
self.assertIsInstance(model._feature_extractor, extractor_class)
if __name__ == '__main__':
tf.test.main()
......@@ -778,7 +778,7 @@ def to_absolute_coordinates(boxlist,
height,
width,
check_range=True,
maximum_normalized_coordinate=1.01,
maximum_normalized_coordinate=1.1,
scope=None):
"""Converts normalized box coordinates to absolute pixel coordinates.
......@@ -792,7 +792,7 @@ def to_absolute_coordinates(boxlist,
width: Maximum value for width of absolute box coordinates.
check_range: If True, checks if the coordinates are normalized or not.
maximum_normalized_coordinate: Maximum coordinate value to be considered
as normalized, default to 1.01.
as normalized, default to 1.1.
scope: name scope.
Returns:
......
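Review note: to see what loosening the default from 1.01 to 1.1 changes, here is a minimal pure-Python sketch of the range check that `maximum_normalized_coordinate` controls. It is not the `box_list_ops` implementation, which builds a TensorFlow assertion into the graph. The `to_absolute` helper and its list-based box format are hypothetical stand-ins, and boxes are assumed to be `[ymin, xmin, ymax, xmax]`.

```python
def to_absolute(boxes, height, width, check_range=True,
                maximum_normalized_coordinate=1.1):
    """Scales normalized [ymin, xmin, ymax, xmax] boxes to pixel coordinates.

    Hypothetical stand-in for illustration only: it raises ValueError where
    the real op would fail a graph assertion at session run time.
    """
    if check_range and boxes:
        largest = max(max(box) for box in boxes)
        if largest > maximum_normalized_coordinate:
            raise ValueError(
                'maximum box coordinate value is larger than %s: %s'
                % (maximum_normalized_coordinate, largest))
    # y coordinates scale with height, x coordinates with width.
    return [[ymin * height, xmin * width, ymax * height, xmax * width]
            for ymin, xmin, ymax, xmax in boxes]
```

Under this sketch a box reaching 1.05 is accepted with the new default of 1.1 but would be rejected with the old default of 1.01, while a box reaching 1.2 is rejected either way.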
......@@ -931,6 +931,21 @@ class CoordinatesConversionTest(tf.test.TestCase):
out = sess.run(boxlist.get())
self.assertAllClose(out, coordinates)
def test_to_absolute_coordinates_maximum_coordinate_check(self):
coordinates = tf.constant([[0, 0, 1.2, 1.2],
[0.25, 0.25, 0.75, 0.75]], tf.float32)
img = tf.ones((128, 100, 100, 3))
boxlist = box_list.BoxList(coordinates)
absolute_boxlist = box_list_ops.to_absolute_coordinates(
boxlist,
tf.shape(img)[1],
tf.shape(img)[2],
maximum_normalized_coordinate=1.1)
with self.test_session() as sess:
with self.assertRaisesOpError('assertion failed'):
sess.run(absolute_boxlist.get())
class BoxRefinementTest(tf.test.TestCase):
......
......@@ -79,10 +79,12 @@ class BoxPredictor(object):
Returns:
A dictionary containing at least the following tensors.
box_encodings: A list of float tensors of shape
[batch_size, num_anchors_i, q, code_size] representing the location of
the objects, where q is 1 or the number of classes. Each entry in the
list corresponds to a feature map in the input `image_features` list.
box_encodings: A list of float tensors. Each entry in the list
corresponds to a feature map in the input `image_features` list. All
tensors in the list have one of the two following shapes:
a. [batch_size, num_anchors_i, q, code_size] representing the location
of the objects, where q is 1 or the number of classes.
b. [batch_size, num_anchors_i, code_size].
class_predictions_with_background: A list of float tensors of shape
[batch_size, num_anchors_i, num_classes + 1] representing the class
predictions for the proposals. Each entry in the list corresponds to a
......@@ -120,10 +122,12 @@ class BoxPredictor(object):
Returns:
A dictionary containing at least the following tensors.
box_encodings: A list of float tensors of shape
[batch_size, num_anchors_i, q, code_size] representing the location of
the objects, where q is 1 or the number of classes. Each entry in the
list corresponds to a feature map in the input `image_features` list.
box_encodings: A list of float tensors. Each entry in the list
corresponds to a feature map in the input `image_features` list. All
tensors in the list have one of the two following shapes:
a. [batch_size, num_anchors_i, q, code_size] representing the location
of the objects, where q is 1 or the number of classes.
b. [batch_size, num_anchors_i, code_size].
class_predictions_with_background: A list of float tensors of shape
[batch_size, num_anchors_i, num_classes + 1] representing the class
predictions for the proposals. Each entry in the list corresponds to a
......@@ -765,6 +769,13 @@ class ConvolutionalBoxPredictor(BoxPredictor):
}
# TODO(rathodv): Replace with slim.arg_scope_func_key once it's available
# externally.
def _arg_scope_func_key(op):
"""Returns a key that can be used to index arg_scope dictionary."""
return getattr(op, '_key_op', str(op))
# TODO(rathodv): Merge the implementation with ConvolutionalBoxPredictor above
# since they are very similar.
class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
......@@ -773,8 +784,12 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
Defines the box predictor as defined in
https://arxiv.org/abs/1708.02002. This class differs from
ConvolutionalBoxPredictor in that it shares weights and biases while
predicting from different feature maps. Separate multi-layer towers are
constructed for the box encoding and class predictors respectively.
predicting from different feature maps. However, batch_norm parameters are not
shared because the statistics of the activations vary among the different
feature maps.
Also note that separate multi-layer towers are constructed for the box
encoding and class predictors respectively.
"""
def __init__(self,
......@@ -833,14 +848,15 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
Returns:
box_encodings: A list of float tensors of shape
[batch_size, num_anchors_i, q, code_size] representing the location of
the objects, where q is 1 or the number of classes. Each entry in the
list corresponds to a feature map in the input `image_features` list.
[batch_size, num_anchors_i, code_size] representing the location of
the objects. Each entry in the list corresponds to a feature map in the
input `image_features` list.
class_predictions_with_background: A list of float tensors of shape
[batch_size, num_anchors_i, num_classes + 1] representing the class
predictions for the proposals. Each entry in the list corresponds to a
feature map in the input `image_features` list.
Raises:
ValueError: If the image feature maps do not have the same number of
channels or if the num predictions per location differs between the
......@@ -858,15 +874,18 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
'channels, found: {}'.format(feature_channels))
box_encodings_list = []
class_predictions_list = []
for (image_feature, num_predictions_per_location) in zip(
image_features, num_predictions_per_location_list):
for feature_index, (image_feature,
num_predictions_per_location) in enumerate(
zip(image_features,
num_predictions_per_location_list)):
# Add a slot for the background class.
with tf.variable_scope('WeightSharedConvolutionalBoxPredictor',
reuse=tf.AUTO_REUSE):
num_class_slots = self.num_classes + 1
box_encodings_net = image_feature
class_predictions_net = image_feature
with slim.arg_scope(self._conv_hyperparams_fn()):
with slim.arg_scope(self._conv_hyperparams_fn()) as sc:
apply_batch_norm = _arg_scope_func_key(slim.batch_norm) in sc
for i in range(self._num_layers_before_predictor):
box_encodings_net = slim.conv2d(
box_encodings_net,
......@@ -874,14 +893,22 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
[self._kernel_size, self._kernel_size],
stride=1,
padding='SAME',
scope='BoxEncodingPredictionTower/conv2d_{}'.format(i))
activation_fn=None,
normalizer_fn=(tf.identity if apply_batch_norm else None),
scope='BoxPredictionTower/conv2d_{}'.format(i))
if apply_batch_norm:
box_encodings_net = slim.batch_norm(
box_encodings_net,
scope='BoxPredictionTower/conv2d_{}/BatchNorm/feature_{}'.
format(i, feature_index))
box_encodings_net = tf.nn.relu6(box_encodings_net)
box_encodings = slim.conv2d(
box_encodings_net,
num_predictions_per_location * self._box_code_size,
[self._kernel_size, self._kernel_size],
activation_fn=None, stride=1, padding='SAME',
normalizer_fn=None,
scope='BoxEncodingPredictor')
scope='BoxPredictor')
for i in range(self._num_layers_before_predictor):
class_predictions_net = slim.conv2d(
......@@ -890,7 +917,15 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
[self._kernel_size, self._kernel_size],
stride=1,
padding='SAME',
activation_fn=None,
normalizer_fn=(tf.identity if apply_batch_norm else None),
scope='ClassPredictionTower/conv2d_{}'.format(i))
if apply_batch_norm:
class_predictions_net = slim.batch_norm(
class_predictions_net,
scope='ClassPredictionTower/conv2d_{}/BatchNorm/feature_{}'
.format(i, feature_index))
class_predictions_net = tf.nn.relu6(class_predictions_net)
if self._use_dropout:
class_predictions_net = slim.dropout(
class_predictions_net, keep_prob=self._dropout_keep_prob)
......@@ -912,7 +947,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
combined_feature_map_shape[1] *
combined_feature_map_shape[2] *
num_predictions_per_location,
1, self._box_code_size]))
self._box_code_size]))
box_encodings_list.append(box_encodings)
class_predictions_with_background = tf.reshape(
class_predictions_with_background,
......
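Review note: the change above drops the singleton `q` axis, so `WeightSharedConvolutionalBoxPredictor` now emits box encodings of shape `[batch_size, num_anchors_i, code_size]`. The shape arithmetic can be sketched in plain Python as below. Both helper names are hypothetical and work on shape lists only; no tensors are involved.

```python
def flattened_box_encoding_shape(batch_size, height, width,
                                 num_predictions_per_location, box_code_size):
    """Shape of one feature map's box encodings after the reshape.

    Hypothetical helper: every spatial location contributes
    num_predictions_per_location anchors, each with box_code_size values.
    """
    num_anchors = height * width * num_predictions_per_location
    return [batch_size, num_anchors, box_code_size]


def concat_along_anchor_axis(shapes):
    """Shape after concatenating per-feature-map encodings on axis 1."""
    batch_size, _, box_code_size = shapes[0]
    return [batch_size, sum(shape[1] for shape in shapes), box_code_size]
```

For a `4 x 8 x 8` feature map with 5 predictions per location and a code size of 4 this gives `[4, 320, 4]`, and two such maps concatenate to `[4, 640, 4]`, which lines up with the updated shape assertions in the predictor tests.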
......@@ -442,6 +442,24 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams)
return hyperparams_builder.build(conv_hyperparams, is_training=True)
def _build_conv_arg_scope_no_batch_norm(self):
conv_hyperparams = hyperparams_pb2.Hyperparams()
conv_hyperparams_text_proto = """
activation: RELU_6
regularizer {
l2_regularizer {
}
}
initializer {
random_normal_initializer {
stddev: 0.01
mean: 0.0
}
}
"""
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams)
return hyperparams_builder.build(conv_hyperparams, is_training=True)
def test_get_boxes_for_five_aspect_ratios_per_location(self):
def graph_fn(image_features):
......@@ -463,7 +481,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
image_features = np.random.rand(4, 8, 8, 64).astype(np.float32)
(box_encodings, objectness_predictions) = self.execute(
graph_fn, [image_features])
self.assertAllEqual(box_encodings.shape, [4, 320, 1, 4])
self.assertAllEqual(box_encodings.shape, [4, 320, 4])
self.assertAllEqual(objectness_predictions.shape, [4, 320, 1])
def test_bias_predictions_to_background_with_sigmoid_score_conversion(self):
......@@ -512,7 +530,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
image_features = np.random.rand(4, 8, 8, 64).astype(np.float32)
(box_encodings, class_predictions_with_background) = self.execute(
graph_fn, [image_features])
self.assertAllEqual(box_encodings.shape, [4, 320, 1, 4])
self.assertAllEqual(box_encodings.shape, [4, 320, 4])
self.assertAllEqual(class_predictions_with_background.shape,
[4, 320, num_classes_without_background+1])
......@@ -543,11 +561,12 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
image_features2 = np.random.rand(4, 8, 8, 64).astype(np.float32)
(box_encodings, class_predictions_with_background) = self.execute(
graph_fn, [image_features1, image_features2])
self.assertAllEqual(box_encodings.shape, [4, 640, 1, 4])
self.assertAllEqual(box_encodings.shape, [4, 640, 4])
self.assertAllEqual(class_predictions_with_background.shape,
[4, 640, num_classes_without_background+1])
def test_predictions_from_multiple_feature_maps_share_weights(self):
def test_predictions_from_multiple_feature_maps_share_weights_not_batchnorm(
self):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
......@@ -574,26 +593,95 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
actual_variable_set = set(
[var.op.name for var in tf.trainable_variables()])
expected_variable_set = set([
# Box prediction tower
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxPredictionTower/conv2d_0/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxPredictionTower/conv2d_0/BatchNorm/feature_0/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictionTower/conv2d_0/weights'),
'BoxPredictionTower/conv2d_0/BatchNorm/feature_1/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictionTower/conv2d_0/BatchNorm/beta'),
'BoxPredictionTower/conv2d_1/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictionTower/conv2d_1/weights'),
'BoxPredictionTower/conv2d_1/BatchNorm/feature_0/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictionTower/conv2d_1/BatchNorm/beta'),
'BoxPredictionTower/conv2d_1/BatchNorm/feature_1/beta'),
# Box prediction head
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxPredictor/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxPredictor/biases'),
# Class prediction tower
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_0/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_0/BatchNorm/beta'),
'ClassPredictionTower/conv2d_0/BatchNorm/feature_0/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_0/BatchNorm/feature_1/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_1/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_1/BatchNorm/beta'),
'ClassPredictionTower/conv2d_1/BatchNorm/feature_0/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_1/BatchNorm/feature_1/beta'),
# Class prediction head
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictor/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictor/biases')])
self.assertEqual(expected_variable_set, actual_variable_set)
def test_no_batchnorm_params_when_batchnorm_is_not_configured(self):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_conv_arg_scope_no_batch_norm(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4)
box_predictions = conv_box_predictor.predict(
[image_features1, image_features2],
num_predictions_per_location=[5, 5],
scope='BoxPredictor')
box_encodings = tf.concat(
box_predictions[box_predictor.BOX_ENCODINGS], axis=1)
class_predictions_with_background = tf.concat(
box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND],
axis=1)
return (box_encodings, class_predictions_with_background)
with self.test_session(graph=tf.Graph()):
graph_fn(tf.random_uniform([4, 32, 32, 3], dtype=tf.float32),
tf.random_uniform([4, 16, 16, 3], dtype=tf.float32))
actual_variable_set = set(
[var.op.name for var in tf.trainable_variables()])
expected_variable_set = set([
# Box prediction tower
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxPredictionTower/conv2d_0/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxPredictionTower/conv2d_0/biases'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxPredictionTower/conv2d_1/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
-'BoxEncodingPredictor/weights'),
+'BoxPredictionTower/conv2d_1/biases'),
# Box prediction head
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
-'BoxEncodingPredictor/biases'),
+'BoxPredictor/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxPredictor/biases'),
# Class prediction tower
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_0/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_0/biases'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_1/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_1/biases'),
# Class prediction head
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictor/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
@@ -628,7 +716,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
[tf.shape(box_encodings), tf.shape(objectness_predictions)],
feed_dict={image_features:
np.random.rand(4, resolution, resolution, 64)})
-self.assertAllEqual(box_encodings_shape, [4, expected_num_anchors, 1, 4])
+self.assertAllEqual(box_encodings_shape, [4, expected_num_anchors, 4])
self.assertAllEqual(objectness_predictions_shape,
[4, expected_num_anchors, 1])
......
@@ -2128,7 +2128,8 @@ def resize_to_range(image,
max_dimension=None,
method=tf.image.ResizeMethod.BILINEAR,
align_corners=False,
-pad_to_max_dimension=False):
+pad_to_max_dimension=False,
+per_channel_pad_value=(0, 0, 0)):
"""Resizes an image so its dimensions are within the provided value.
The output size can be described by two cases:
@@ -2153,6 +2154,8 @@ def resize_to_range(image,
so the resulting image is of the spatial size
[max_dimension, max_dimension]. If masks are included they are padded
similarly.
+per_channel_pad_value: A tuple of per-channel scalar value to use for
+padding. By default pads zeros.
Returns:
Note that the position of the resized_image_shape changes based on whether
@@ -2181,8 +2184,20 @@ def resize_to_range(image,
image, new_size[:-1], method=method, align_corners=align_corners)
if pad_to_max_dimension:
-new_image = tf.image.pad_to_bounding_box(
-new_image, 0, 0, max_dimension, max_dimension)
+channels = tf.unstack(new_image, axis=2)
+if len(channels) != len(per_channel_pad_value):
+raise ValueError('Number of channels must be equal to the length of '
+'per-channel pad value.')
+new_image = tf.stack(
+[
+tf.pad(
+channels[i], [[0, max_dimension - new_size[0]],
+[0, max_dimension - new_size[1]]],
+constant_values=per_channel_pad_value[i])
+for i in range(len(channels))
+],
+axis=2)
+new_image.set_shape([max_dimension, max_dimension, 3])
result = [new_image]
if masks is not None:
......
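The per-channel padding added to resize_to_range in the hunk above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the library code: `pad_per_channel` is a hypothetical helper that works on nested lists shaped [height][width][channels], and it assumes, as the tf.pad call suggests, that the original pixels stay in the top-left corner while padding is appended at the bottom and right.

```python
def pad_per_channel(image, max_dimension, per_channel_pad_value=(0, 0, 0)):
    """Pads a nested-list image to max_dimension x max_dimension.

    image is indexed as image[y][x][channel]. Every added pixel is filled
    with per_channel_pad_value, one scalar per channel. This is a
    hypothetical stand-in for the unstack / pad / stack sequence in the diff.
    """
    height = len(image)
    width = len(image[0]) if height else 0
    if height > max_dimension or width > max_dimension:
        raise ValueError('Image is larger than max_dimension.')
    # Mirror the channel-count check from the diff.
    for row in image:
        for pixel in row:
            if len(pixel) != len(per_channel_pad_value):
                raise ValueError('Number of channels must be equal to the '
                                 'length of per-channel pad value.')
    padded = []
    for y in range(max_dimension):
        row = []
        for x in range(max_dimension):
            if y < height and x < width:
                row.append(list(image[y][x]))  # original pixel, copied
            else:
                row.append(list(per_channel_pad_value))  # padding pixel
        padded.append(row)
    return padded


# Same input and pad values as the ReturnsCorrectTensor test in the diff.
padded = pad_per_channel([[[0, 1, 2]]], 2, (123.68, 116.779, 103.939))
```

With a 1x1 input and max_dimension of 2, the single original pixel lands at position [0][0] and the other three positions hold the pad values, which matches the expected array in the accompanying test.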
@@ -2316,6 +2316,46 @@ class PreprocessorTest(tf.test.TestCase):
np.random.randn(*in_shape)})
self.assertAllEqual(out_image_shape, expected_shape)
+def testResizeToRangeWithPadToMaxDimensionReturnsCorrectShapes(self):
+in_shape_list = [[60, 40, 3], [15, 30, 3], [15, 50, 3]]
+min_dim = 50
+max_dim = 100
+expected_shape_list = [[100, 100, 3], [100, 100, 3], [100, 100, 3]]
+for in_shape, expected_shape in zip(in_shape_list, expected_shape_list):
+in_image = tf.placeholder(tf.float32, shape=(None, None, 3))
+out_image, _ = preprocessor.resize_to_range(
+in_image,
+min_dimension=min_dim,
+max_dimension=max_dim,
+pad_to_max_dimension=True)
+self.assertAllEqual(out_image.shape.as_list(), expected_shape)
+out_image_shape = tf.shape(out_image)
+with self.test_session() as sess:
+out_image_shape = sess.run(
+out_image_shape, feed_dict={in_image: np.random.randn(*in_shape)})
+self.assertAllEqual(out_image_shape, expected_shape)
+def testResizeToRangeWithPadToMaxDimensionReturnsCorrectTensor(self):
+in_image_np = np.array([[[0, 1, 2]]], np.float32)
+ex_image_np = np.array(
+[[[0, 1, 2], [123.68, 116.779, 103.939]],
+[[123.68, 116.779, 103.939], [123.68, 116.779, 103.939]]], np.float32)
+min_dim = 1
+max_dim = 2
+in_image = tf.placeholder(tf.float32, shape=(None, None, 3))
+out_image, _ = preprocessor.resize_to_range(
+in_image,
+min_dimension=min_dim,
+max_dimension=max_dim,
+pad_to_max_dimension=True,
+per_channel_pad_value=(123.68, 116.779, 103.939))
+with self.test_session() as sess:
+out_image_np = sess.run(out_image, feed_dict={in_image: in_image_np})
+self.assertAllClose(ex_image_np, out_image_np)
def testResizeToRangeWithMasksPreservesStaticSpatialShape(self):
"""Tests image resizing, checking output sizes."""
in_image_shape_list = [[60, 40, 3], [15, 30, 3]]
......
@@ -34,6 +34,7 @@ class InputDataFields(object):
Attributes:
image: image.
+image_additional_channels: additional channels.
original_image: image in the original input size.
key: unique key corresponding to image.
source_id: source of the original image.
@@ -66,6 +67,7 @@ class InputDataFields(object):
multiclass_scores: the label score per class for each box.
"""
image = 'image'
+image_additional_channels = 'image_additional_channels'
original_image = 'original_image'
key = 'key'
source_id = 'source_id'
@@ -161,6 +163,8 @@ class TfExampleFields(object):
height: height of image in pixels, e.g. 462
width: width of image in pixels, e.g. 581
source_id: original source of the image
+image_class_text: image-level label in text format
+image_class_label: image-level label in numerical format
object_class_text: labels in text format, e.g. ["person", "cat"]
object_class_label: labels in numbers, e.g. [16, 8]
object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30
@@ -195,6 +199,8 @@ class TfExampleFields(object):
height = 'image/height'
width = 'image/width'
source_id = 'image/source_id'
+image_class_text = 'image/class/text'
+image_class_label = 'image/class/label'
object_class_text = 'image/object/class/text'
object_class_label = 'image/object/class/label'
object_bbox_ymin = 'image/object/bbox/ymin'
......
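To show how the new image-level keys sit next to the existing object-level keys, here is a small sketch. The four key strings are copied from the TfExampleFields hunk above; the helper `make_label_features` and the plain-dict layout are illustrative assumptions and do not use the tf.train.Example API. The label values reuse the docstring examples ("person" = 16, "cat" = 8).

```python
# Key strings copied from the TfExampleFields hunk; everything else here is
# a hypothetical illustration, not part of the Object Detection API.
IMAGE_CLASS_TEXT = 'image/class/text'
IMAGE_CLASS_LABEL = 'image/class/label'
OBJECT_CLASS_TEXT = 'image/object/class/text'
OBJECT_CLASS_LABEL = 'image/object/class/label'


def make_label_features(image_labels, object_labels):
    """Builds a plain dict keyed by the field names above.

    Both arguments are lists of (text, numeric_id) pairs. Image-level
    labels describe the whole image; object-level labels have one entry
    per groundtruth box.
    """
    return {
        IMAGE_CLASS_TEXT: [text for text, _ in image_labels],
        IMAGE_CLASS_LABEL: [label for _, label in image_labels],
        OBJECT_CLASS_TEXT: [text for text, _ in object_labels],
        OBJECT_CLASS_LABEL: [label for _, label in object_labels],
    }


features = make_label_features(
    image_labels=[('cat', 8)],
    object_labels=[('person', 16), ('cat', 8)])
```

The image-level lists and the object-level lists are independent, so an image can carry a verified image-level label even when it has a different number of boxes.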
item {
name: "/m/061hd_"
id: 1
display_name: "Infant bed"
}
item {
name: "/m/06m11"
id: 2
display_name: "Rose"
}
item {
name: "/m/03120"
id: 3
display_name: "Flag"
}
item {
name: "/m/01kb5b"
id: 4
display_name: "Flashlight"
}
item {
name: "/m/0120dh"
id: 5
display_name: "Sea turtle"
}
item {
name: "/m/0dv5r"
id: 6
display_name: "Camera"
}
item {
name: "/m/0jbk"
id: 7
display_name: "Animal"
}
item {
name: "/m/0174n1"
id: 8
display_name: "Glove"
}
item {
name: "/m/09f_2"
id: 9
display_name: "Crocodile"
}
item {
name: "/m/01xq0k1"
id: 10
display_name: "Cattle"
}
item {
name: "/m/03jm5"
id: 11
display_name: "House"
}
item {
name: "/m/02g30s"
id: 12
display_name: "Guacamole"
}
item {
name: "/m/05z6w"
id: 13
display_name: "Penguin"
}
item {
name: "/m/01jfm_"
id: 14
display_name: "Vehicle registration plate"
}
item {
name: "/m/076lb9"
id: 15
display_name: "Training bench"
}
item {
name: "/m/0gj37"
id: 16
display_name: "Ladybug"
}
item {
name: "/m/0k0pj"
id: 17
display_name: "Human nose"
}
item {
name: "/m/0kpqd"
id: 18
display_name: "Watermelon"
}
item {
name: "/m/0l14j_"
id: 19
display_name: "Flute"
}
item {
name: "/m/0cyf8"
id: 20
display_name: "Butterfly"
}
item {
name: "/m/0174k2"
id: 21
display_name: "Washing machine"
}
item {
name: "/m/0dq75"
id: 22
display_name: "Raccoon"
}
item {
name: "/m/076bq"
id: 23
display_name: "Segway"
}
item {
name: "/m/07crc"
id: 24
display_name: "Taco"
}
item {
name: "/m/0d8zb"
id: 25
display_name: "Jellyfish"
}
item {
name: "/m/0fszt"
id: 26
display_name: "Cake"
}
item {
name: "/m/0k1tl"
id: 27
display_name: "Pen"
}
item {
name: "/m/020kz"
id: 28
display_name: "Cannon"
}
item {
name: "/m/09728"
id: 29
display_name: "Bread"
}
item {
name: "/m/07j7r"
id: 30
display_name: "Tree"
}
item {
name: "/m/0fbdv"
id: 31
display_name: "Shellfish"
}
item {
name: "/m/03ssj5"
id: 32
display_name: "Bed"
}
item {
name: "/m/03qrc"
id: 33
display_name: "Hamster"
}
item {
name: "/m/02dl1y"
id: 34
display_name: "Hat"
}
item {
name: "/m/01k6s3"
id: 35
display_name: "Toaster"
}
item {
name: "/m/02jfl0"
id: 36
display_name: "Sombrero"
}
item {
name: "/m/01krhy"
id: 37
display_name: "Tiara"
}
item {
name: "/m/04kkgm"
id: 38
display_name: "Bowl"
}
item {
name: "/m/0ft9s"
id: 39
display_name: "Dragonfly"
}
item {
name: "/m/0d_2m"
id: 40
display_name: "Moths and butterflies"
}
item {
name: "/m/0czz2"
id: 41
display_name: "Antelope"
}
item {
name: "/m/0f4s2w"
id: 42
display_name: "Vegetable"
}
item {
name: "/m/07dd4"
id: 43
display_name: "Torch"
}
item {
name: "/m/0cgh4"
id: 44
display_name: "Building"
}
item {
name: "/m/03bbps"
id: 45
display_name: "Power plugs and sockets"
}
item {
name: "/m/02pjr4"
id: 46
display_name: "Blender"
}
item {
name: "/m/04p0qw"
id: 47
display_name: "Billiard table"
}
item {
name: "/m/02pdsw"
id: 48
display_name: "Cutting board"
}
item {
name: "/m/01yx86"
id: 49
display_name: "Bronze sculpture"
}
item {
name: "/m/09dzg"
id: 50
display_name: "Turtle"
}
item {
name: "/m/0hkxq"
id: 51
display_name: "Broccoli"
}
item {
name: "/m/07dm6"
id: 52
display_name: "Tiger"
}
item {
name: "/m/054_l"
id: 53
display_name: "Mirror"
}
item {
name: "/m/01dws"
id: 54
display_name: "Bear"
}
item {
name: "/m/027pcv"
id: 55
display_name: "Zucchini"
}
item {
name: "/m/01d40f"
id: 56
display_name: "Dress"
}
item {
name: "/m/02rgn06"
id: 57
display_name: "Volleyball"
}
item {
name: "/m/0342h"
id: 58
display_name: "Guitar"
}
item {
name: "/m/06bt6"
id: 59
display_name: "Reptile"
}
item {
name: "/m/0323sq"
id: 60
display_name: "Golf cart"
}
item {
name: "/m/02zvsm"
id: 61
display_name: "Tart"
}
item {
name: "/m/02fq_6"
id: 62
display_name: "Fedora"
}
item {
name: "/m/01lrl"
id: 63
display_name: "Carnivore"
}
item {
name: "/m/0k4j"
id: 64
display_name: "Car"
}
item {
name: "/m/04h7h"
id: 65
display_name: "Lighthouse"
}
item {
name: "/m/07xyvk"
id: 66
display_name: "Coffeemaker"
}
item {
name: "/m/03y6mg"
id: 67
display_name: "Food processor"
}
item {
name: "/m/07r04"
id: 68
display_name: "Truck"
}
item {
name: "/m/03__z0"
id: 69
display_name: "Bookcase"
}
item {
name: "/m/019w40"
id: 70
display_name: "Surfboard"
}
item {
name: "/m/09j5n"
id: 71
display_name: "Footwear"
}
item {
name: "/m/0cvnqh"
id: 72
display_name: "Bench"
}
item {
name: "/m/01llwg"
id: 73
display_name: "Necklace"
}
item {
name: "/m/0c9ph5"
id: 74
display_name: "Flower"
}
item {
name: "/m/015x5n"
id: 75
display_name: "Radish"
}
item {
name: "/m/0gd2v"
id: 76
display_name: "Marine mammal"
}
item {
name: "/m/04v6l4"
id: 77
display_name: "Frying pan"
}
item {
name: "/m/02jz0l"
id: 78
display_name: "Tap"
}
item {
name: "/m/0dj6p"
id: 79
display_name: "Peach"
}
item {
name: "/m/04ctx"
id: 80
display_name: "Knife"
}
item {
name: "/m/080hkjn"
id: 81
display_name: "Handbag"
}
item {
name: "/m/01c648"
id: 82
display_name: "Laptop"
}
item {
name: "/m/01j61q"
id: 83
display_name: "Tent"
}
item {
name: "/m/012n7d"
id: 84
display_name: "Ambulance"
}
item {
name: "/m/025nd"
id: 85
display_name: "Christmas tree"
}
item {
name: "/m/09csl"
id: 86
display_name: "Eagle"
}
item {
name: "/m/01lcw4"
id: 87
display_name: "Limousine"
}
item {
name: "/m/0h8n5zk"
id: 88
display_name: "Kitchen & dining room table"
}
item {
name: "/m/0633h"
id: 89
display_name: "Polar bear"
}
item {
name: "/m/01fdzj"
id: 90
display_name: "Tower"
}
item {
name: "/m/01226z"
id: 91
display_name: "Football"
}
item {
name: "/m/0mw_6"
id: 92
display_name: "Willow"
}
item {
name: "/m/04hgtk"
id: 93
display_name: "Human head"
}
item {
name: "/m/02pv19"
id: 94
display_name: "Stop sign"
}
item {
name: "/m/09qck"
id: 95
display_name: "Banana"
}
item {
name: "/m/063rgb"
id: 96
display_name: "Mixer"
}
item {
name: "/m/0lt4_"
id: 97
display_name: "Binoculars"
}
item {
name: "/m/0270h"
id: 98
display_name: "Dessert"
}
item {
name: "/m/01h3n"
id: 99
display_name: "Bee"
}
item {
name: "/m/01mzpv"
id: 100
display_name: "Chair"
}
item {
name: "/m/04169hn"
id: 101
display_name: "Wood-burning stove"
}
item {
name: "/m/0fm3zh"
id: 102
display_name: "Flowerpot"
}
item {
name: "/m/0d20w4"
id: 103
display_name: "Beaker"
}
item {
name: "/m/0_cp5"
id: 104
display_name: "Oyster"
}
item {
name: "/m/01dy8n"
id: 105
display_name: "Woodpecker"
}
item {
name: "/m/03m5k"
id: 106
display_name: "Harp"
}
item {
name: "/m/03dnzn"
id: 107
display_name: "Bathtub"
}
item {
name: "/m/0h8mzrc"
id: 108
display_name: "Wall clock"
}
item {
name: "/m/0h8mhzd"
id: 109
display_name: "Sports uniform"
}
item {
name: "/m/03d443"
id: 110
display_name: "Rhinoceros"
}
item {
name: "/m/01gllr"
id: 111
display_name: "Beehive"
}
item {
name: "/m/0642b4"
id: 112
display_name: "Cupboard"
}
item {
name: "/m/09b5t"
id: 113
display_name: "Chicken"
}
item {
name: "/m/04yx4"
id: 114
display_name: "Man"
}
item {
name: "/m/01f8m5"
id: 115
display_name: "Blue jay"
}
item {
name: "/m/015x4r"
id: 116
display_name: "Cucumber"
}
item {
name: "/m/01j51"
id: 117
display_name: "Balloon"
}
item {
name: "/m/02zt3"
id: 118
display_name: "Kite"
}
item {
name: "/m/03tw93"
id: 119
display_name: "Fireplace"
}
item {
name: "/m/01jfsr"
id: 120
display_name: "Lantern"
}
item {
name: "/m/04ylt"
id: 121
display_name: "Missile"
}
item {
name: "/m/0bt_c3"
id: 122
display_name: "Book"
}
item {
name: "/m/0cmx8"
id: 123
display_name: "Spoon"
}
item {
name: "/m/0hqkz"
id: 124
display_name: "Grapefruit"
}
item {
name: "/m/071qp"
id: 125
display_name: "Squirrel"
}
item {
name: "/m/0cyhj_"
id: 126
display_name: "Orange"
}
item {
name: "/m/01xygc"
id: 127
display_name: "Coat"
}
item {
name: "/m/0420v5"
id: 128
display_name: "Punching bag"
}
item {
name: "/m/0898b"
id: 129
display_name: "Zebra"
}
item {
name: "/m/01knjb"
id: 130
display_name: "Billboard"
}
item {
name: "/m/0199g"
id: 131
display_name: "Bicycle"
}
item {
name: "/m/03c7gz"
id: 132
display_name: "Door handle"
}
item {
name: "/m/02x984l"
id: 133
display_name: "Mechanical fan"
}
item {
name: "/m/04zwwv"
id: 134
display_name: "Ring binder"
}
item {
name: "/m/04bcr3"
id: 135
display_name: "Table"
}
item {
name: "/m/0gv1x"
id: 136
display_name: "Parrot"
}
item {
name: "/m/01nq26"
id: 137
display_name: "Sock"
}
item {
name: "/m/02s195"
id: 138
display_name: "Vase"
}
item {
name: "/m/083kb"
id: 139
display_name: "Weapon"
}
item {
name: "/m/06nrc"
id: 140
display_name: "Shotgun"
}
item {
name: "/m/0jyfg"
id: 141
display_name: "Glasses"
}
item {
name: "/m/0nybt"
id: 142
display_name: "Seahorse"
}
item {
name: "/m/0176mf"
id: 143
display_name: "Belt"
}
item {
name: "/m/01rzcn"
id: 144
display_name: "Watercraft"
}
item {
name: "/m/0d4v4"
id: 145
display_name: "Window"
}
item {
name: "/m/03bk1"
id: 146
display_name: "Giraffe"
}
item {
name: "/m/096mb"
id: 147
display_name: "Lion"
}
item {
name: "/m/0h9mv"
id: 148
display_name: "Tire"
}
item {
name: "/m/07yv9"
id: 149
display_name: "Vehicle"
}
item {
name: "/m/0ph39"
id: 150
display_name: "Canoe"
}
item {
name: "/m/01rkbr"
id: 151
display_name: "Tie"
}
item {
name: "/m/0gjbg72"
id: 152
display_name: "Shelf"
}
item {
name: "/m/06z37_"
id: 153
display_name: "Picture frame"
}
item {
name: "/m/01m4t"
id: 154
display_name: "Printer"
}
item {
name: "/m/035r7c"
id: 155
display_name: "Human leg"
}
item {
name: "/m/019jd"
id: 156
display_name: "Boat"
}
item {
name: "/m/02tsc9"
id: 157
display_name: "Slow cooker"
}
item {
name: "/m/015wgc"
id: 158
display_name: "Croissant"
}
item {
name: "/m/0c06p"
id: 159
display_name: "Candle"
}
item {
name: "/m/01dwwc"
id: 160
display_name: "Pancake"
}
item {
name: "/m/034c16"
id: 161
display_name: "Pillow"
}
item {
name: "/m/0242l"
id: 162
display_name: "Coin"
}
item {
name: "/m/02lbcq"
id: 163
display_name: "Stretcher"
}
item {
name: "/m/03nfch"
id: 164
display_name: "Sandal"
}
item {
name: "/m/03bt1vf"
id: 165
display_name: "Woman"
}
item {
name: "/m/01lynh"
id: 166
display_name: "Stairs"
}
item {
name: "/m/03q5t"
id: 167
display_name: "Harpsichord"
}
item {
name: "/m/0fqt361"
id: 168
display_name: "Stool"
}
item {
name: "/m/01bjv"
id: 169
display_name: "Bus"
}
item {
name: "/m/01s55n"
id: 170
display_name: "Suitcase"
}
item {
name: "/m/0283dt1"
id: 171
display_name: "Human mouth"
}
item {
name: "/m/01z1kdw"
id: 172
display_name: "Juice"
}
item {
name: "/m/016m2d"
id: 173
display_name: "Skull"
}
item {
name: "/m/02dgv"
id: 174
display_name: "Door"
}
item {
name: "/m/07y_7"
id: 175
display_name: "Violin"
}
item {
name: "/m/01_5g"
id: 176
display_name: "Chopsticks"
}
item {
name: "/m/06_72j"
id: 177
display_name: "Digital clock"
}
item {
name: "/m/0ftb8"
id: 178
display_name: "Sunflower"
}
item {
name: "/m/0c29q"
id: 179
display_name: "Leopard"
}
item {
name: "/m/0jg57"
id: 180
display_name: "Bell pepper"
}
item {
name: "/m/02l8p9"
id: 181
display_name: "Harbor seal"
}
item {
name: "/m/078jl"
id: 182
display_name: "Snake"
}
item {
name: "/m/0llzx"
id: 183
display_name: "Sewing machine"
}
item {
name: "/m/0dbvp"
id: 184
display_name: "Goose"
}
item {
name: "/m/09ct_"
id: 185
display_name: "Helicopter"
}
item {
name: "/m/0dkzw"
id: 186
display_name: "Seat belt"
}
item {
name: "/m/02p5f1q"
id: 187
display_name: "Coffee cup"
}
item {
name: "/m/0fx9l"
id: 188
display_name: "Microwave oven"
}
item {
name: "/m/01b9xk"
id: 189
display_name: "Hot dog"
}
item {
name: "/m/0b3fp9"
id: 190
display_name: "Countertop"
}
item {
name: "/m/0h8n27j"
id: 191
display_name: "Serving tray"
}
item {
name: "/m/0h8n6f9"
id: 192
display_name: "Dog bed"
}
item {
name: "/m/01599"
id: 193
display_name: "Beer"
}
item {
name: "/m/017ftj"
id: 194
display_name: "Sunglasses"
}
item {
name: "/m/044r5d"
id: 195
display_name: "Golf ball"
}
item {
name: "/m/01dwsz"
id: 196
display_name: "Waffle"
}
item {
name: "/m/0cdl1"
id: 197
display_name: "Palm tree"
}
item {
name: "/m/07gql"
id: 198
display_name: "Trumpet"
}
item {
name: "/m/0hdln"
id: 199
display_name: "Ruler"
}
item {
name: "/m/0zvk5"
id: 200
display_name: "Helmet"
}
item {
name: "/m/012w5l"
id: 201
display_name: "Ladder"
}
item {
name: "/m/021sj1"
id: 202
display_name: "Office building"
}
item {
name: "/m/0bh9flk"
id: 203
display_name: "Tablet computer"
}
item {
name: "/m/09gtd"
id: 204
display_name: "Toilet paper"
}
item {
name: "/m/0jwn_"
id: 205
display_name: "Pomegranate"
}
item {
name: "/m/02wv6h6"
id: 206
display_name: "Skirt"
}
item {
name: "/m/02wv84t"
id: 207
display_name: "Gas stove"
}
item {
name: "/m/021mn"
id: 208
display_name: "Cookie"
}
item {
name: "/m/018p4k"
id: 209
display_name: "Cart"
}
item {
name: "/m/06j2d"
id: 210
display_name: "Raven"
}
item {
name: "/m/033cnk"
id: 211
display_name: "Egg"
}
item {
name: "/m/01j3zr"
id: 212
display_name: "Burrito"
}
item {
name: "/m/03fwl"
id: 213
display_name: "Goat"
}
item {
name: "/m/058qzx"
id: 214
display_name: "Kitchen knife"
}
item {
name: "/m/06_fw"
id: 215
display_name: "Skateboard"
}
item {
name: "/m/02x8cch"
id: 216
display_name: "Salt and pepper shakers"
}
item {
name: "/m/04g2r"
id: 217
display_name: "Lynx"
}
item {
name: "/m/01b638"
id: 218
display_name: "Boot"
}
item {
name: "/m/099ssp"
id: 219
display_name: "Platter"
}
item {
name: "/m/071p9"
id: 220
display_name: "Ski"
}
item {
name: "/m/01gkx_"
id: 221
display_name: "Swimwear"
}
item {
name: "/m/0b_rs"
id: 222
display_name: "Swimming pool"
}
item {
name: "/m/03v5tg"
id: 223
display_name: "Drinking straw"
}
item {
name: "/m/01j5ks"
id: 224
display_name: "Wrench"
}
item {
name: "/m/026t6"
id: 225
display_name: "Drum"
}
item {
name: "/m/0_k2"
id: 226
display_name: "Ant"
}
item {
name: "/m/039xj_"
id: 227
display_name: "Human ear"
}
item {
name: "/m/01b7fy"
id: 228
display_name: "Headphones"
}
item {
name: "/m/0220r2"
id: 229
display_name: "Fountain"
}
item {
name: "/m/015p6"
id: 230
display_name: "Bird"
}
item {
name: "/m/0fly7"
id: 231
display_name: "Jeans"
}
item {
name: "/m/07c52"
id: 232
display_name: "Television"
}
item {
name: "/m/0n28_"
id: 233
display_name: "Crab"
}
item {
name: "/m/0hg7b"
id: 234
display_name: "Microphone"
}
item {
name: "/m/019dx1"
id: 235
display_name: "Home appliance"
}
item {
name: "/m/04vv5k"
id: 236
display_name: "Snowplow"
}
item {
name: "/m/020jm"
id: 237
display_name: "Beetle"
}
item {
name: "/m/047v4b"
id: 238
display_name: "Artichoke"
}
item {
name: "/m/01xs3r"
id: 239
display_name: "Jet ski"
}
item {
name: "/m/03kt2w"
id: 240
display_name: "Stationary bicycle"
}
item {
name: "/m/03q69"
id: 241
display_name: "Human hair"
}
item {
name: "/m/01dxs"
id: 242
display_name: "Brown bear"
}
item {
name: "/m/01h8tj"
id: 243
display_name: "Starfish"
}
item {
name: "/m/0dt3t"
id: 244
display_name: "Fork"
}
item {
name: "/m/0cjq5"
id: 245
display_name: "Lobster"
}
item {
name: "/m/0h8lkj8"
id: 246
display_name: "Corded phone"
}
item {
name: "/m/0271t"
id: 247
display_name: "Drink"
}
item {
name: "/m/03q5c7"
id: 248
display_name: "Saucer"
}
item {
name: "/m/0fj52s"
id: 249
display_name: "Carrot"
}
item {
name: "/m/03vt0"
id: 250
display_name: "Insect"
}
item {
name: "/m/01x3z"
id: 251
display_name: "Clock"
}
item {
name: "/m/0d5gx"
id: 252
display_name: "Castle"
}
item {
name: "/m/0h8my_4"
id: 253
display_name: "Tennis racket"
}
item {
name: "/m/03ldnb"
id: 254
display_name: "Ceiling fan"
}
item {
name: "/m/0cjs7"
id: 255
display_name: "Asparagus"
}
item {
name: "/m/0449p"
id: 256
display_name: "Jaguar"
}
item {
name: "/m/04szw"
id: 257
display_name: "Musical instrument"
}
item {
name: "/m/07jdr"
id: 258
display_name: "Train"
}
item {
name: "/m/01yrx"
id: 259
display_name: "Cat"
}
item {
name: "/m/06c54"
id: 260
display_name: "Rifle"
}
item {
name: "/m/04h8sr"
id: 261
display_name: "Dumbbell"
}
item {
name: "/m/050k8"
id: 262
display_name: "Mobile phone"
}
item {
name: "/m/0pg52"
id: 263
display_name: "Taxi"
}
item {
name: "/m/02f9f_"
id: 264
display_name: "Shower"
}
item {
name: "/m/054fyh"
id: 265
display_name: "Pitcher"
}
item {
name: "/m/09k_b"
id: 266
display_name: "Lemon"
}
item {
name: "/m/03xxp"
id: 267
display_name: "Invertebrate"
}
item {
name: "/m/0jly1"
id: 268
display_name: "Turkey"
}
item {
name: "/m/06k2mb"
id: 269
display_name: "High heels"
}
item {
name: "/m/04yqq2"
id: 270
display_name: "Bust"
}
item {
name: "/m/0bwd_0j"
id: 271
display_name: "Elephant"
}
item {
name: "/m/02h19r"
id: 272
display_name: "Scarf"
}
item {
name: "/m/02zn6n"
id: 273
display_name: "Barrel"
}
item {
name: "/m/07c6l"
id: 274
display_name: "Trombone"
}
item {
name: "/m/05zsy"
id: 275
display_name: "Pumpkin"
}
item {
name: "/m/025dyy"
id: 276
display_name: "Box"
}
item {
name: "/m/07j87"
id: 277
display_name: "Tomato"
}
item {
name: "/m/09ld4"
id: 278
display_name: "Frog"
}
item {
name: "/m/01vbnl"
id: 279
display_name: "Bidet"
}
item {
name: "/m/0dzct"
id: 280
display_name: "Human face"
}
item {
name: "/m/03fp41"
id: 281
display_name: "Houseplant"
}
item {
name: "/m/0h2r6"
id: 282
display_name: "Van"
}
item {
name: "/m/0by6g"
id: 283
display_name: "Shark"
}
item {
name: "/m/0cxn2"
id: 284
display_name: "Ice cream"
}
item {
name: "/m/04tn4x"
id: 285
display_name: "Swim cap"
}
item {
name: "/m/0f6wt"
id: 286
display_name: "Falcon"
}
item {
name: "/m/05n4y"
id: 287
display_name: "Ostrich"
}
item {
name: "/m/0gxl3"
id: 288
display_name: "Handgun"
}
item {
name: "/m/02d9qx"
id: 289
display_name: "Whiteboard"
}
item {
name: "/m/04m9y"
id: 290
display_name: "Lizard"
}
item {
name: "/m/05z55"
id: 291
display_name: "Pasta"
}
item {
name: "/m/01x3jk"
id: 292
display_name: "Snowmobile"
}
item {
name: "/m/0h8l4fh"
id: 293
display_name: "Light bulb"
}
item {
name: "/m/031b6r"
id: 294
display_name: "Window blind"
}
item {
name: "/m/01tcjp"
id: 295
display_name: "Muffin"
}
item {
name: "/m/01f91_"
id: 296
display_name: "Pretzel"
}
item {
name: "/m/02522"
id: 297
display_name: "Computer monitor"
}
item {
name: "/m/0319l"
id: 298
display_name: "Horn"
}
item {
name: "/m/0c_jw"
id: 299
display_name: "Furniture"
}
item {
name: "/m/0l515"
id: 300
display_name: "Sandwich"
}
item {
name: "/m/0306r"
id: 301
display_name: "Fox"
}
item {
name: "/m/0crjs"
id: 302
display_name: "Convenience store"
}
item {
name: "/m/0ch_cf"
id: 303
display_name: "Fish"
}
item {
name: "/m/02xwb"
id: 304
display_name: "Fruit"
}
item {
name: "/m/01r546"
id: 305
display_name: "Earrings"
}
item {
name: "/m/03rszm"
id: 306
display_name: "Curtain"
}
item {
name: "/m/0388q"
id: 307
display_name: "Grape"
}
item {
name: "/m/03m3pdh"
id: 308
display_name: "Sofa bed"
}
item {
name: "/m/03k3r"
id: 309
display_name: "Horse"
}
item {
name: "/m/0hf58v5"
id: 310
display_name: "Luggage and bags"
}
item {
name: "/m/01y9k5"
id: 311
display_name: "Desk"
}
item {
name: "/m/05441v"
id: 312
display_name: "Crutch"
}
item {
name: "/m/03p3bw"
id: 313
display_name: "Bicycle helmet"
}
item {
name: "/m/0175cv"
id: 314
display_name: "Tick"
}
item {
name: "/m/0cmf2"
id: 315
display_name: "Airplane"
}
item {
name: "/m/0ccs93"
id: 316
display_name: "Canary"
}
item {
name: "/m/02d1br"
id: 317
display_name: "Spatula"
}
item {
name: "/m/0gjkl"
id: 318
display_name: "Watch"
}
item {
name: "/m/0jqgx"
id: 319
display_name: "Lily"
}
item {
name: "/m/0h99cwc"
id: 320
display_name: "Kitchen appliance"
}
item {
name: "/m/047j0r"
id: 321
display_name: "Filing cabinet"
}
item {
name: "/m/0k5j"
id: 322
display_name: "Aircraft"
}
item {
name: "/m/0h8n6ft"
id: 323
display_name: "Cake stand"
}
item {
name: "/m/0gm28"
id: 324
display_name: "Candy"
}
item {
name: "/m/0130jx"
id: 325
display_name: "Sink"
}
item {
name: "/m/04rmv"
id: 326
display_name: "Mouse"
}
item {
name: "/m/081qc"
id: 327
display_name: "Wine"
}
item {
name: "/m/0qmmr"
id: 328
display_name: "Wheelchair"
}
item {
name: "/m/03fj2"
id: 329
display_name: "Goldfish"
}
item {
name: "/m/040b_t"
id: 330
display_name: "Refrigerator"
}
item {
name: "/m/02y6n"
id: 331
display_name: "French fries"
}
item {
name: "/m/0fqfqc"
id: 332
display_name: "Drawer"
}
item {
name: "/m/030610"
id: 333
display_name: "Treadmill"
}
item {
name: "/m/07kng9"
id: 334
display_name: "Picnic basket"
}
item {
name: "/m/029b3"
id: 335
display_name: "Dice"
}
item {
name: "/m/0fbw6"
id: 336
display_name: "Cabbage"
}
item {
name: "/m/07qxg_"
id: 337
display_name: "Football helmet"
}
item {
name: "/m/068zj"
id: 338
display_name: "Pig"
}
item {
name: "/m/01g317"
id: 339
display_name: "Person"
}
item {
name: "/m/01bfm9"
id: 340
display_name: "Shorts"
}
item {
name: "/m/02068x"
id: 341
display_name: "Gondola"
}
item {
name: "/m/0fz0h"
id: 342
display_name: "Honeycomb"
}
item {
name: "/m/0jy4k"
id: 343
display_name: "Doughnut"
}
item {
name: "/m/05kyg_"
id: 344
display_name: "Chest of drawers"
}
item {
name: "/m/01prls"
id: 345
display_name: "Land vehicle"
}
item {
name: "/m/01h44"
id: 346
display_name: "Bat"
}
item {
name: "/m/08pbxl"
id: 347
display_name: "Monkey"
}
item {
name: "/m/02gzp"
id: 348
display_name: "Dagger"
}
item {
name: "/m/04brg2"
id: 349
display_name: "Tableware"
}
item {
name: "/m/031n1"
id: 350
display_name: "Human foot"
}
item {
name: "/m/02jvh9"
id: 351
display_name: "Mug"
}
item {
name: "/m/046dlr"
id: 352
display_name: "Alarm clock"
}
item {
name: "/m/0h8ntjv"
id: 353
display_name: "Pressure cooker"
}
item {
name: "/m/0k65p"
id: 354
display_name: "Human hand"
}
item {
name: "/m/011k07"
id: 355
display_name: "Tortoise"
}
item {
name: "/m/03grzl"
id: 356
display_name: "Baseball glove"
}
item {
name: "/m/06y5r"
id: 357
display_name: "Sword"
}
item {
name: "/m/061_f"
id: 358
display_name: "Pear"
}
item {
name: "/m/01cmb2"
id: 359
display_name: "Miniskirt"
}
item {
name: "/m/01mqdt"
id: 360
display_name: "Traffic sign"
}
item {
name: "/m/05r655"
id: 361
display_name: "Girl"
}
item {
name: "/m/02p3w7d"
id: 362
display_name: "Roller skates"
}
item {
name: "/m/029tx"
id: 363
display_name: "Dinosaur"
}
item {
name: "/m/04m6gz"
id: 364
display_name: "Porch"
}
item {
name: "/m/015h_t"
id: 365
display_name: "Human beard"
}
item {
name: "/m/06pcq"
id: 366
display_name: "Submarine sandwich"
}
item {
name: "/m/01bms0"
id: 367
display_name: "Screwdriver"
}
item {
name: "/m/07fbm7"
id: 368
display_name: "Strawberry"
}
item {
name: "/m/09tvcd"
id: 369
display_name: "Wine glass"
}
item {
name: "/m/06nwz"
id: 370
display_name: "Seafood"
}
item {
name: "/m/0dv9c"
id: 371
display_name: "Racket"
}
item {
name: "/m/083wq"
id: 372
display_name: "Wheel"
}
item {
name: "/m/0gd36"
id: 373
display_name: "Sea lion"
}
item {
name: "/m/0138tl"
id: 374
display_name: "Toy"
}
item {
name: "/m/07clx"
id: 375
display_name: "Tea"
}
item {
name: "/m/05ctyq"
id: 376
display_name: "Tennis ball"
}
item {
name: "/m/0bjyj5"
id: 377
display_name: "Waste container"
}
item {
name: "/m/0dbzx"
id: 378
display_name: "Mule"
}
item {
name: "/m/02ctlc"
id: 379
display_name: "Cricket ball"
}
item {
name: "/m/0fp6w"
id: 380
display_name: "Pineapple"
}
item {
name: "/m/0djtd"
id: 381
display_name: "Coconut"
}
item {
name: "/m/0167gd"
id: 382
display_name: "Doll"
}
item {
name: "/m/078n6m"
id: 383
display_name: "Coffee table"
}
item {
name: "/m/0152hh"
id: 384
display_name: "Snowman"
}
item {
name: "/m/04gth"
id: 385
display_name: "Lavender"
}
item {
name: "/m/0ll1f78"
id: 386
display_name: "Shrimp"
}
item {
name: "/m/0cffdh"
id: 387
display_name: "Maple"
}
item {
name: "/m/025rp__"
id: 388
display_name: "Cowboy hat"
}
item {
name: "/m/02_n6y"
id: 389
display_name: "Goggles"
}
item {
name: "/m/0wdt60w"
id: 390
display_name: "Rugby ball"
}
item {
name: "/m/0cydv"
id: 391
display_name: "Caterpillar"
}
item {
name: "/m/01n5jq"
id: 392
display_name: "Poster"
}
item {
name: "/m/09rvcxw"
id: 393
display_name: "Rocket"
}
item {
name: "/m/013y1f"
id: 394
display_name: "Organ"
}
item {
name: "/m/06ncr"
id: 395
display_name: "Saxophone"
}
item {
name: "/m/015qff"
id: 396
display_name: "Traffic light"
}
item {
name: "/m/024g6"
id: 397
display_name: "Cocktail"
}
item {
name: "/m/05gqfk"
id: 398
display_name: "Plastic bag"
}
item {
name: "/m/0dv77"
id: 399
display_name: "Squash"
}
item {
name: "/m/052sf"
id: 400
display_name: "Mushroom"
}
item {
name: "/m/0cdn1"
id: 401
display_name: "Hamburger"
}
item {
name: "/m/03jbxj"
id: 402
display_name: "Light switch"
}
item {
name: "/m/0cyfs"
id: 403
display_name: "Parachute"
}
item {
name: "/m/0kmg4"
id: 404
display_name: "Teddy bear"
}
item {
name: "/m/02cvgx"
id: 405
display_name: "Winter melon"
}
item {
name: "/m/09kx5"
id: 406
display_name: "Deer"
}
item {
name: "/m/057cc"
id: 407
display_name: "Musical keyboard"
}
item {
name: "/m/02pkr5"
id: 408
display_name: "Plumbing fixture"
}
item {
name: "/m/057p5t"
id: 409
display_name: "Scoreboard"
}
item {
name: "/m/03g8mr"
id: 410
display_name: "Baseball bat"
}
item {
name: "/m/0frqm"
id: 411
display_name: "Envelope"
}
item {
name: "/m/03m3vtv"
id: 412
display_name: "Adhesive tape"
}
item {
name: "/m/0584n8"
id: 413
display_name: "Briefcase"
}
item {
name: "/m/014y4n"
id: 414
display_name: "Paddle"
}
item {
name: "/m/01g3x7"
id: 415
display_name: "Bow and arrow"
}
item {
name: "/m/07cx4"
id: 416
display_name: "Telephone"
}
item {
name: "/m/07bgp"
id: 417
display_name: "Sheep"
}
item {
name: "/m/032b3c"
id: 418
display_name: "Jacket"
}
item {
name: "/m/01bl7v"
id: 419
display_name: "Boy"
}
item {
name: "/m/0663v"
id: 420
display_name: "Pizza"
}
item {
name: "/m/0cn6p"
id: 421
display_name: "Otter"
}
item {
name: "/m/02rdsp"
id: 422
display_name: "Office supplies"
}
item {
name: "/m/02crq1"
id: 423
display_name: "Couch"
}
item {
name: "/m/01xqw"
id: 424
display_name: "Cello"
}
item {
name: "/m/0cnyhnx"
id: 425
display_name: "Bull"
}
item {
name: "/m/01x_v"
id: 426
display_name: "Camel"
}
item {
name: "/m/018xm"
id: 427
display_name: "Ball"
}
item {
name: "/m/09ddx"
id: 428
display_name: "Duck"
}
item {
name: "/m/084zz"
id: 429
display_name: "Whale"
}
item {
name: "/m/01n4qj"
id: 430
display_name: "Shirt"
}
item {
name: "/m/07cmd"
id: 431
display_name: "Tank"
}
item {
name: "/m/04_sv"
id: 432
display_name: "Motorcycle"
}
item {
name: "/m/0mkg"
id: 433
display_name: "Accordion"
}
item {
name: "/m/09d5_"
id: 434
display_name: "Owl"
}
item {
name: "/m/0c568"
id: 435
display_name: "Porcupine"
}
item {
name: "/m/02wbtzl"
id: 436
display_name: "Sun hat"
}
item {
name: "/m/05bm6"
id: 437
display_name: "Nail"
}
item {
name: "/m/01lsmm"
id: 438
display_name: "Scissors"
}
item {
name: "/m/0dftk"
id: 439
display_name: "Swan"
}
item {
name: "/m/0dtln"
id: 440
display_name: "Lamp"
}
item {
name: "/m/0nl46"
id: 441
display_name: "Crown"
}
item {
name: "/m/05r5c"
id: 442
display_name: "Piano"
}
item {
name: "/m/06msq"
id: 443
display_name: "Sculpture"
}
item {
name: "/m/0cd4d"
id: 444
display_name: "Cheetah"
}
item {
name: "/m/05kms"
id: 445
display_name: "Oboe"
}
item {
name: "/m/02jnhm"
id: 446
display_name: "Tin can"
}
item {
name: "/m/0fldg"
id: 447
display_name: "Mango"
}
item {
name: "/m/073bxn"
id: 448
display_name: "Tripod"
}
item {
name: "/m/029bxz"
id: 449
display_name: "Oven"
}
item {
name: "/m/020lf"
id: 450
display_name: "Computer mouse"
}
item {
name: "/m/01btn"
id: 451
display_name: "Barge"
}
item {
name: "/m/02vqfm"
id: 452
display_name: "Coffee"
}
item {
name: "/m/06__v"
id: 453
display_name: "Snowboard"
}
item {
name: "/m/043nyj"
id: 454
display_name: "Common fig"
}
item {
name: "/m/0grw1"
id: 455
display_name: "Salad"
}
item {
name: "/m/03hl4l9"
id: 456
display_name: "Marine invertebrates"
}
item {
name: "/m/0hnnb"
id: 457
display_name: "Umbrella"
}
item {
name: "/m/04c0y"
id: 458
display_name: "Kangaroo"
}
item {
name: "/m/0dzf4"
id: 459
display_name: "Human arm"
}
item {
name: "/m/07v9_z"
id: 460
display_name: "Measuring cup"
}
item {
name: "/m/0f9_l"
id: 461
display_name: "Snail"
}
item {
name: "/m/0703r8"
id: 462
display_name: "Loveseat"
}
item {
name: "/m/01xyhv"
id: 463
display_name: "Suit"
}
item {
name: "/m/01fh4r"
id: 464
display_name: "Teapot"
}
item {
name: "/m/04dr76w"
id: 465
display_name: "Bottle"
}
item {
name: "/m/0pcr"
id: 466
display_name: "Alpaca"
}
item {
name: "/m/03s_tn"
id: 467
display_name: "Kettle"
}
item {
name: "/m/07mhn"
id: 468
display_name: "Trousers"
}
item {
name: "/m/01hrv5"
id: 469
display_name: "Popcorn"
}
item {
name: "/m/019h78"
id: 470
display_name: "Centipede"
}
item {
name: "/m/09kmb"
id: 471
display_name: "Spider"
}
item {
name: "/m/0h23m"
id: 472
display_name: "Sparrow"
}
item {
name: "/m/050gv4"
id: 473
display_name: "Plate"
}
item {
name: "/m/01fb_0"
id: 474
display_name: "Bagel"
}
item {
name: "/m/02w3_ws"
id: 475
display_name: "Personal care"
}
item {
name: "/m/014j1m"
id: 476
display_name: "Apple"
}
item {
name: "/m/01gmv2"
id: 477
display_name: "Brassiere"
}
item {
name: "/m/04y4h8h"
id: 478
display_name: "Bathroom cabinet"
}
item {
name: "/m/026qbn5"
id: 479
display_name: "Studio couch"
}
item {
name: "/m/01m2v"
id: 480
display_name: "Computer keyboard"
}
item {
name: "/m/05_5p_0"
id: 481
display_name: "Table tennis racket"
}
item {
name: "/m/07030"
id: 482
display_name: "Sushi"
}
item {
name: "/m/01s105"
id: 483
display_name: "Cabinetry"
}
item {
name: "/m/033rq4"
id: 484
display_name: "Street light"
}
item {
name: "/m/0162_1"
id: 485
display_name: "Towel"
}
item {
name: "/m/02z51p"
id: 486
display_name: "Nightstand"
}
item {
name: "/m/06mf6"
id: 487
display_name: "Rabbit"
}
item {
name: "/m/02hj4"
id: 488
display_name: "Dolphin"
}
item {
name: "/m/0bt9lr"
id: 489
display_name: "Dog"
}
item {
name: "/m/08hvt4"
id: 490
display_name: "Jug"
}
item {
name: "/m/084rd"
id: 491
display_name: "Wok"
}
item {
name: "/m/01pns0"
id: 492
display_name: "Fire hydrant"
}
item {
name: "/m/014sv8"
id: 493
display_name: "Human eye"
}
item {
name: "/m/079cl"
id: 494
display_name: "Skyscraper"
}
item {
name: "/m/01940j"
id: 495
display_name: "Backpack"
}
item {
name: "/m/05vtc"
id: 496
display_name: "Potato"
}
item {
name: "/m/02w3r3"
id: 497
display_name: "Paper towel"
}
item {
name: "/m/054xkw"
id: 498
display_name: "Lifejacket"
}
item {
name: "/m/01bqk0"
id: 499
display_name: "Bicycle wheel"
}
item {
name: "/m/09g1w"
id: 500
display_name: "Toilet"
}
......@@ -112,7 +112,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
label_map_proto_file=None,
use_display_name=False,
dct_method='',
num_keypoints=0):
num_keypoints=0,
num_additional_channels=0):
"""Constructor sets keys_to_features and items_to_handlers.
Args:
......@@ -133,6 +134,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
example, the jpeg library does not have that specific option.
num_keypoints: the number of keypoints per object.
num_additional_channels: how many additional channels to use.
Raises:
ValueError: If `instance_mask_type` option is not one of
......@@ -178,15 +180,28 @@ class TfExampleDecoder(data_decoder.DataDecoder):
'image/object/weight':
tf.VarLenFeature(tf.float32),
}
# We are checking `dct_method` instead of passing it directly in order to
# ensure TF version 1.6 compatibility.
if dct_method:
image = slim_example_decoder.Image(
image_key='image/encoded',
format_key='image/format',
channels=3,
dct_method=dct_method)
additional_channel_image = slim_example_decoder.Image(
image_key='image/additional_channels/encoded',
format_key='image/format',
channels=1,
repeated=True,
dct_method=dct_method)
else:
image = slim_example_decoder.Image(
image_key='image/encoded', format_key='image/format', channels=3)
additional_channel_image = slim_example_decoder.Image(
image_key='image/additional_channels/encoded',
format_key='image/format',
channels=1,
repeated=True)
self.items_to_handlers = {
fields.InputDataFields.image:
image,
......@@ -211,6 +226,13 @@ class TfExampleDecoder(data_decoder.DataDecoder):
fields.InputDataFields.groundtruth_weights: (
slim_example_decoder.Tensor('image/object/weight')),
}
if num_additional_channels > 0:
self.keys_to_features[
'image/additional_channels/encoded'] = tf.FixedLenFeature(
(num_additional_channels,), tf.string)
self.items_to_handlers[
fields.InputDataFields.
image_additional_channels] = additional_channel_image
self._num_keypoints = num_keypoints
if num_keypoints > 0:
self.keys_to_features['image/object/keypoint/x'] = (
......@@ -294,6 +316,9 @@ class TfExampleDecoder(data_decoder.DataDecoder):
[None] indicating if the boxes enclose a crowd.
Optional:
fields.InputDataFields.image_additional_channels - 3D uint8 tensor of
shape [None, None, num_additional_channels]. 1st dim is height; 2nd dim
is width; 3rd dim is the number of additional channels.
fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
[None] indicating if the boxes represent `difficult` instances.
fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
......@@ -316,6 +341,12 @@ class TfExampleDecoder(data_decoder.DataDecoder):
tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]
if fields.InputDataFields.image_additional_channels in tensor_dict:
channels = tensor_dict[fields.InputDataFields.image_additional_channels]
channels = tf.squeeze(channels, axis=3)
channels = tf.transpose(channels, perm=[1, 2, 0])
tensor_dict[fields.InputDataFields.image_additional_channels] = channels
def default_groundtruth_weights():
return tf.ones(
[tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]],
......
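The squeeze and transpose added to `decode` above can be checked outside TensorFlow. The following is a minimal NumPy sketch, not the decoder itself: it assumes the repeated image handler yields a tensor of shape [num_channels, height, width, 1], and the helper name `merge_additional_channels` is hypothetical.

```python
import numpy as np


def merge_additional_channels(decoded_channels):
  """Rearranges [num_channels, height, width, 1] into [height, width, num_channels]."""
  channels = np.squeeze(decoded_channels, axis=3)  # -> [num_channels, height, width]
  return np.transpose(channels, (1, 2, 0))  # -> [height, width, num_channels]


# Two hypothetical single-channel 4x5 images, stacked the way a decoder with
# repeated=True is assumed to return them.
single = [np.full((4, 5, 1), fill, dtype=np.uint8) for fill in (10, 20)]
decoded = np.stack(single, axis=0)
merged = merge_additional_channels(decoded)
print(merged.shape)
```

Under that assumption the result matches `np.concatenate(single, axis=2)`, which is the same comparison `testDecodeAdditionalChannels` makes.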
......@@ -23,6 +23,7 @@ from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
......@@ -72,10 +73,41 @@ class TfExampleDecoderTest(tf.test.TestCase):
def _BytesFeatureFromList(self, ndarray):
values = ndarray.flatten().tolist()
for i in range(len(values)):
values[i] = values[i].encode('utf-8')
return feature_pb2.Feature(bytes_list=feature_pb2.BytesList(value=values))
def testDecodeAdditionalChannels(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
additional_channel_tensor = np.random.randint(
256, size=(4, 5, 1)).astype(np.uint8)
encoded_additional_channel = self._EncodeImage(additional_channel_tensor)
decoded_additional_channel = self._DecodeImage(encoded_additional_channel)
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
self._BytesFeature(encoded_jpeg),
'image/additional_channels/encoded':
self._BytesFeatureFromList(
np.array([encoded_additional_channel] * 2)),
'image/format':
self._BytesFeature('jpeg'),
'image/source_id':
self._BytesFeature('image_id'),
})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder(
num_additional_channels=2)
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual(
np.concatenate([decoded_additional_channel] * 2, axis=2),
tensor_dict[fields.InputDataFields.image_additional_channels])
def testDecodeExampleWithBranchedBackupHandler(self):
example1 = example_pb2.Example(
features=feature_pb2.Features(
......@@ -304,6 +336,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual(
2, tensor_dict[fields.InputDataFields.num_groundtruth_boxes])
@test_util.enable_c_shapes
def testDecodeKeypoint(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
......@@ -331,7 +364,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
get_shape().as_list()), [None, 4])
self.assertAllEqual((tensor_dict[fields.InputDataFields.
groundtruth_keypoints].
get_shape().as_list()), [None, 3, 2])
get_shape().as_list()), [2, 3, 2])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
......@@ -376,6 +409,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllClose(tensor_dict[fields.InputDataFields.groundtruth_weights],
np.ones(2, dtype=np.float32))
@test_util.enable_c_shapes
def testDecodeObjectLabel(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
......@@ -391,7 +425,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual((tensor_dict[
fields.InputDataFields.groundtruth_classes].get_shape().as_list()),
[None])
[2])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
......@@ -522,6 +556,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual([3, 1],
tensor_dict[fields.InputDataFields.groundtruth_classes])
@test_util.enable_c_shapes
def testDecodeObjectArea(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
......@@ -536,13 +571,14 @@ class TfExampleDecoderTest(tf.test.TestCase):
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual((tensor_dict[fields.InputDataFields.groundtruth_area].
get_shape().as_list()), [None])
get_shape().as_list()), [2])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual(object_area,
tensor_dict[fields.InputDataFields.groundtruth_area])
@test_util.enable_c_shapes
def testDecodeObjectIsCrowd(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
......@@ -558,7 +594,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual((tensor_dict[
fields.InputDataFields.groundtruth_is_crowd].get_shape().as_list()),
[None])
[2])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
......@@ -566,6 +602,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
tensor_dict[
fields.InputDataFields.groundtruth_is_crowd])
@test_util.enable_c_shapes
def testDecodeObjectDifficult(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
......@@ -581,7 +618,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual((tensor_dict[
fields.InputDataFields.groundtruth_difficult].get_shape().as_list()),
[None])
[2])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
......@@ -589,6 +626,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
tensor_dict[
fields.InputDataFields.groundtruth_difficult])
@test_util.enable_c_shapes
def testDecodeObjectGroupOf(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
......@@ -605,7 +643,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual((tensor_dict[
fields.InputDataFields.groundtruth_group_of].get_shape().as_list()),
[None])
[2])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
......@@ -637,6 +675,7 @@ class TfExampleDecoderTest(tf.test.TestCase):
object_weights,
tensor_dict[fields.InputDataFields.groundtruth_weights])
@test_util.enable_c_shapes
def testDecodeInstanceSegmentation(self):
num_instances = 4
image_height = 5
......@@ -673,11 +712,11 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual((
tensor_dict[fields.InputDataFields.groundtruth_instance_masks].
get_shape().as_list()), [None, None, None])
get_shape().as_list()), [4, 5, 3])
self.assertAllEqual((
tensor_dict[fields.InputDataFields.groundtruth_classes].
get_shape().as_list()), [None])
get_shape().as_list()), [4])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
......
......@@ -16,7 +16,8 @@ r"""Creates TFRecords of Open Images dataset for object detection.
Example usage:
python object_detection/dataset_tools/create_oid_tf_record.py \
--input_annotations_csv=/path/to/input/annotations-human-bbox.csv \
--input_box_annotations_csv=/path/to/input/annotations-human-bbox.csv \
--input_image_label_annotations_csv=/path/to/input/annotations-label.csv \
--input_images_directory=/path/to/input/image_pixels_directory \
--input_label_map=/path/to/input/labels_bbox_545.labelmap \
--output_tf_record_path_prefix=/path/to/output/prefix.tfrecord
......@@ -27,7 +28,9 @@ https://github.com/openimages/dataset
This script will include every image found in the input_images_directory in the
output TFRecord, even if the image has no corresponding bounding box annotations
in the input_annotations_csv.
in the input_annotations_csv. If input_image_label_annotations_csv is specified,
it will add image-level labels as well. Note that the information of whether a
label is positively or negatively verified is NOT added to tfrecord.
"""
from __future__ import absolute_import
from __future__ import division
......@@ -40,13 +43,16 @@ import pandas as pd
import tensorflow as tf
from object_detection.dataset_tools import oid_tfrecord_creation
from object_detection.dataset_tools import tf_record_creation_util
from object_detection.utils import label_map_util
tf.flags.DEFINE_string('input_annotations_csv', None,
tf.flags.DEFINE_string('input_box_annotations_csv', None,
'Path to CSV containing image bounding box annotations')
tf.flags.DEFINE_string('input_images_directory', None,
'Directory containing the image pixels '
'downloaded from the OpenImages GitHub repository.')
tf.flags.DEFINE_string('input_image_label_annotations_csv', None,
'Path to CSV containing image-level labels annotations')
tf.flags.DEFINE_string('input_label_map', None, 'Path to the label map proto')
tf.flags.DEFINE_string(
'output_tf_record_path_prefix', None,
......@@ -61,7 +67,7 @@ def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
required_flags = [
'input_annotations_csv', 'input_images_directory', 'input_label_map',
'input_box_annotations_csv', 'input_images_directory', 'input_label_map',
'output_tf_record_path_prefix'
]
for flag_name in required_flags:
......@@ -69,17 +75,24 @@ def main(_):
raise ValueError('Flag --{} is required'.format(flag_name))
label_map = label_map_util.get_label_map_dict(FLAGS.input_label_map)
all_annotations = pd.read_csv(FLAGS.input_annotations_csv)
all_box_annotations = pd.read_csv(FLAGS.input_box_annotations_csv)
if FLAGS.input_image_label_annotations_csv:
all_label_annotations = pd.read_csv(FLAGS.input_image_label_annotations_csv)
all_label_annotations.rename(
columns={'Confidence': 'ConfidenceImageLabel'}, inplace=True)
else:
all_label_annotations = None
all_images = tf.gfile.Glob(
os.path.join(FLAGS.input_images_directory, '*.jpg'))
all_image_ids = [os.path.splitext(os.path.basename(v))[0] for v in all_images]
all_image_ids = pd.DataFrame({'ImageID': all_image_ids})
all_annotations = pd.concat([all_annotations, all_image_ids])
all_annotations = pd.concat(
[all_box_annotations, all_image_ids, all_label_annotations])
tf.logging.log(tf.logging.INFO, 'Found %d images...', len(all_image_ids))
with contextlib2.ExitStack() as tf_record_close_stack:
output_tfrecords = oid_tfrecord_creation.open_sharded_output_tfrecords(
output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
tf_record_close_stack, FLAGS.output_tf_record_path_prefix,
FLAGS.num_shards)
......
......@@ -33,11 +33,13 @@ import os
import random
import re
import contextlib2
from lxml import etree
import numpy as np
import PIL.Image
import tensorflow as tf
from object_detection.dataset_tools import tf_record_creation_util
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
......@@ -52,6 +54,8 @@ flags.DEFINE_boolean('faces_only', True, 'If True, generates bounding boxes '
'in the latter case, the resulting files are much larger.')
flags.DEFINE_string('mask_type', 'png', 'How to represent instance '
'segmentation masks. Options are "png" or "numerical".')
flags.DEFINE_integer('num_shards', 10, 'Number of TFRecord shards')
FLAGS = flags.FLAGS
......@@ -208,6 +212,7 @@ def dict_to_tf_example(data,
def create_tf_record(output_filename,
num_shards,
label_map_dict,
annotations_dir,
image_dir,
......@@ -218,6 +223,7 @@ def create_tf_record(output_filename,
Args:
output_filename: Path to where output file is saved.
num_shards: Number of shards for output file.
label_map_dict: The label map dictionary.
annotations_dir: Directory where annotation files are stored.
image_dir: Directory where image files are stored.
......@@ -227,34 +233,36 @@ def create_tf_record(output_filename,
mask_type: 'numerical' or 'png'. 'png' is recommended because it leads to
smaller file sizes.
"""
writer = tf.python_io.TFRecordWriter(output_filename)
for idx, example in enumerate(examples):
if idx % 100 == 0:
logging.info('On image %d of %d', idx, len(examples))
xml_path = os.path.join(annotations_dir, 'xmls', example + '.xml')
mask_path = os.path.join(annotations_dir, 'trimaps', example + '.png')
if not os.path.exists(xml_path):
logging.warning('Could not find %s, ignoring example.', xml_path)
continue
with tf.gfile.GFile(xml_path, 'r') as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
try:
tf_example = dict_to_tf_example(
data,
mask_path,
label_map_dict,
image_dir,
faces_only=faces_only,
mask_type=mask_type)
writer.write(tf_example.SerializeToString())
except ValueError:
logging.warning('Invalid example: %s, ignoring.', xml_path)
writer.close()
with contextlib2.ExitStack() as tf_record_close_stack:
output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
tf_record_close_stack, output_filename, num_shards)
for idx, example in enumerate(examples):
if idx % 100 == 0:
logging.info('On image %d of %d', idx, len(examples))
xml_path = os.path.join(annotations_dir, 'xmls', example + '.xml')
mask_path = os.path.join(annotations_dir, 'trimaps', example + '.png')
if not os.path.exists(xml_path):
logging.warning('Could not find %s, ignoring example.', xml_path)
continue
with tf.gfile.GFile(xml_path, 'r') as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
try:
tf_example = dict_to_tf_example(
data,
mask_path,
label_map_dict,
image_dir,
faces_only=faces_only,
mask_type=mask_type)
if tf_example:
shard_idx = idx % num_shards
output_tfrecords[shard_idx].write(tf_example.SerializeToString())
except ValueError:
logging.warning('Invalid example: %s, ignoring.', xml_path)
# TODO(derekjchow): Add test for pet/PASCAL main files.
......@@ -279,15 +287,16 @@ def main(_):
logging.info('%d training and %d validation examples.',
len(train_examples), len(val_examples))
train_output_path = os.path.join(FLAGS.output_dir, 'pet_train.record')
val_output_path = os.path.join(FLAGS.output_dir, 'pet_val.record')
if FLAGS.faces_only:
train_output_path = os.path.join(FLAGS.output_dir, 'pet_faces_train.record')
val_output_path = os.path.join(FLAGS.output_dir, 'pet_faces_val.record')
if not FLAGS.faces_only:
train_output_path = os.path.join(FLAGS.output_dir,
'pet_train_with_masks.record')
'pets_fullbody_with_masks_train.record')
val_output_path = os.path.join(FLAGS.output_dir,
'pet_val_with_masks.record')
'pets_fullbody_with_masks_val.record')
create_tf_record(
train_output_path,
FLAGS.num_shards,
label_map_dict,
annotations_dir,
image_dir,
......@@ -296,6 +305,7 @@ def main(_):
mask_type=FLAGS.mask_type)
create_tf_record(
val_output_path,
FLAGS.num_shards,
label_map_dict,
annotations_dir,
image_dir,
......
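The sharded writing introduced in `create_tf_record` assigns example `idx` to shard `idx % num_shards`. Below is a rough standard-library sketch of that round-robin pattern using plain text files instead of TFRecord writers. The function `write_round_robin` and the shard file naming are assumptions for illustration, not the behaviour of `tf_record_creation_util.open_sharded_output_tfrecords`.

```python
import contextlib
import os
import tempfile


def write_round_robin(records, output_prefix, num_shards):
  """Writes record `idx` to shard `idx % num_shards` and returns the shard paths."""
  # Hypothetical naming scheme; the real helper may name its shards differently.
  paths = ['{}-{:05d}-of-{:05d}'.format(output_prefix, i, num_shards)
           for i in range(num_shards)]
  with contextlib.ExitStack() as stack:
    # ExitStack closes every shard on exit, even if a write raises.
    shards = [stack.enter_context(open(path, 'w')) for path in paths]
    for idx, record in enumerate(records):
      shards[idx % num_shards].write(record + '\n')
  return paths


def count_lines(path):
  with open(path) as f:
    return sum(1 for _ in f)


with tempfile.TemporaryDirectory() as tmp_dir:
  shard_paths = write_round_robin(
      ['record_%d' % i for i in range(7)],
      os.path.join(tmp_dir, 'pets.record'), 3)
  shard_sizes = [count_lines(path) for path in shard_paths]
print(shard_sizes)
```

In the diff the shard index comes from the position in `examples`, including examples that are later skipped, so shard sizes may differ slightly more than in this sketch.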
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A class and executable to expand hierarchically image-level labels and boxes.
Example usage:
./hierarchical_labels_expansion <path to JSON hierarchy> <input csv file>
<output csv file> [optional]labels_file
"""
import json
import sys
def _update_dict(initial_dict, update):
"""Updates dictionary with update content.
Args:
initial_dict: initial dictionary.
update: updated dictionary.
"""
for key, value_list in update.iteritems():
if key in initial_dict:
initial_dict[key].extend(value_list)
else:
initial_dict[key] = value_list
def _build_plain_hierarchy(hierarchy, skip_root=False):
"""Expands tree hierarchy representation to parent-child dictionary.
Args:
hierarchy: labels hierarchy as JSON file.
skip_root: if true skips root from the processing (done for the case when all
classes under hierarchy are collected under virtual node).
Returns:
keyed_parent - dictionary of parent - all its children nodes.
keyed_child - dictionary of children - all its parent nodes
children - all children of the current node.
"""
all_children = []
all_keyed_parent = {}
all_keyed_child = {}
if 'Subcategory' in hierarchy:
for node in hierarchy['Subcategory']:
keyed_parent, keyed_child, children = _build_plain_hierarchy(node)
# Update is not done through dict.update() since some children have
# multiple parents in the hierarchy.
_update_dict(all_keyed_parent, keyed_parent)
_update_dict(all_keyed_child, keyed_child)
all_children.extend(children)
if not skip_root:
all_keyed_parent[hierarchy['LabelName']] = all_children
all_children = [hierarchy['LabelName']] + all_children
for child, _ in all_keyed_child.iteritems():
all_keyed_child[child].append(hierarchy['LabelName'])
all_keyed_child[hierarchy['LabelName']] = []
return all_keyed_parent, all_keyed_child, all_children
class OIDHierarchicalLabelsExpansion(object):
"""Main class to perform hierarchical expansion of labels."""
def __init__(self, hierarchy):
"""Constructor.
Args:
hierarchy: labels hierarchy as JSON file.
"""
self._hierarchy_keyed_parent, self._hierarchy_keyed_child, _ = (
_build_plain_hierarchy(hierarchy, skip_root=True))
def expand_boxes_from_csv(self, csv_row):
"""Expands a row containing bounding boxes from CSV file.
Args:
csv_row: a single row of Open Images released groundtruth file.
Returns:
a list of strings (including the initial row) corresponding to the ground
truth expanded to multiple annotations for evaluation with the Open Images
Challenge 2018 metric.
"""
# Row header is expected to be exactly:
# ImageID,Source,LabelName,Confidence,XMin,XMax,YMin,YMax,IsOccluded,
# IsTruncated,IsGroupOf,IsDepiction,IsInside
cvs_row_splited = csv_row.split(',')
assert len(cvs_row_splited) == 13
result = [csv_row]
assert cvs_row_splited[2] in self._hierarchy_keyed_child
parent_nodes = self._hierarchy_keyed_child[cvs_row_splited[2]]
for parent_node in parent_nodes:
cvs_row_splited[2] = parent_node
result.append(','.join(cvs_row_splited))
return result
def expand_labels_from_csv(self, csv_row):
"""Expands a row containing image-level labels from CSV file.
Args:
csv_row: a single row of Open Images released groundtruth file.
Returns:
a list of strings (including the initial row) corresponding to the ground
truth expanded to multiple annotations for evaluation with the Open Images
Challenge 2018 metric.
"""
# Row header is expected to be exactly:
# ImageID,Source,LabelName,Confidence
cvs_row_splited = csv_row.split(',')
assert len(cvs_row_splited) == 4
result = [csv_row]
if int(cvs_row_splited[3]) == 1:
assert cvs_row_splited[2] in self._hierarchy_keyed_child
parent_nodes = self._hierarchy_keyed_child[cvs_row_splited[2]]
for parent_node in parent_nodes:
cvs_row_splited[2] = parent_node
result.append(','.join(cvs_row_splited))
else:
assert cvs_row_splited[2] in self._hierarchy_keyed_parent
child_nodes = self._hierarchy_keyed_parent[cvs_row_splited[2]]
for child_node in child_nodes:
cvs_row_splited[2] = child_node
result.append(','.join(cvs_row_splited))
return result
def main(argv):
if len(argv) < 4:
print """Missing arguments. \n
Usage: ./hierarchical_labels_expansion <path to JSON hierarchy>
<input csv file> <output csv file> [optional]labels_file"""
return
with open(argv[1]) as f:
hierarchy = json.load(f)
expansion_generator = OIDHierarchicalLabelsExpansion(hierarchy)
labels_file = False
if len(argv) > 4 and argv[4] == 'labels_file':
labels_file = True
with open(argv[2], 'r') as source:
with open(argv[3], 'w') as target:
header_skipped = False
for line in source:
if not header_skipped:
header_skipped = True
continue
if labels_file:
expanded_lines = expansion_generator.expand_labels_from_csv(line)
else:
expanded_lines = expansion_generator.expand_boxes_from_csv(line)
target.writelines(expanded_lines)
if __name__ == '__main__':
main(sys.argv)
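The file above relies on Python 2 idioms (`iteritems`, the `print` statement). The box expansion it performs can be illustrated with a compact Python 3 sketch. This is a simplified reimplementation under stated assumptions, not the module's code: `build_ancestors` and `expand_box_row` are hypothetical names, and the toy hierarchy mirrors the shape of the one used in the accompanying test.

```python
def build_ancestors(node, ancestors=(), table=None):
  """Maps every label under `node` to the set of all its ancestors."""
  table = {} if table is None else table
  label = node['LabelName']
  # A label reachable through several parents accumulates all of them.
  table.setdefault(label, set()).update(ancestors)
  for child in node.get('Subcategory', []):
    build_ancestors(child, ancestors + (label,), table)
  return table


def expand_box_row(csv_row, ancestors):
  """Returns the row plus one copy per ancestor of its label (column 2)."""
  fields = csv_row.split(',')
  rows = [csv_row]
  for parent in sorted(ancestors[fields[2]]):
    rows.append(','.join(fields[:2] + [parent] + fields[3:]))
  return rows


hierarchy = {
    'LabelName': 'a',
    'Subcategory': [
        {'LabelName': 'b'},
        {'LabelName': 'c',
         'Subcategory': [{'LabelName': 'd'}, {'LabelName': 'e'}]},
        {'LabelName': 'f', 'Subcategory': [{'LabelName': 'd'}]},
    ],
}

# Skip the virtual root by starting from its children.
ancestors = {}
for top_level in hierarchy['Subcategory']:
  build_ancestors(top_level, table=ancestors)

expanded = expand_box_row(
    '123,xclick,d,1,0.2,0.3,0.1,0.2,1,1,0,0,0', ancestors)
print(expanded)
```

Label `d` sits under both `c` and `f`, so its row is copied once per parent. This is why the original code merges dictionaries with `_update_dict` rather than `dict.update()`.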
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the OpenImages label expansion (OIDHierarchicalLabelsExpansion)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from object_detection.dataset_tools import oid_hierarchical_labels_expansion
def create_test_data():
hierarchy = {
'LabelName':
'a',
'Subcategory': [{
'LabelName': 'b'
}, {
'LabelName': 'c',
'Subcategory': [{
'LabelName': 'd'
}, {
'LabelName': 'e'
}]
}, {
'LabelName': 'f',
'Subcategory': [{
'LabelName': 'd'
},]
}]
}
bbox_rows = [
'123,xclick,b,1,0.1,0.2,0.1,0.2,1,1,0,0,0',
'123,xclick,d,1,0.2,0.3,0.1,0.2,1,1,0,0,0'
]
label_rows = [
'123,verification,b,0', '123,verification,c,0', '124,verification,d,1'
]
return hierarchy, bbox_rows, label_rows
class HierarchicalLabelsExpansionTest(tf.test.TestCase):
def test_bbox_expansion(self):
hierarchy, bbox_rows, _ = create_test_data()
expansion_generator = (
oid_hierarchical_labels_expansion.OIDHierarchicalLabelsExpansion(
hierarchy))
all_result_rows = []
for row in bbox_rows:
all_result_rows.extend(expansion_generator.expand_boxes_from_csv(row))
self.assertItemsEqual([
'123,xclick,b,1,0.1,0.2,0.1,0.2,1,1,0,0,0',
'123,xclick,d,1,0.2,0.3,0.1,0.2,1,1,0,0,0',
'123,xclick,f,1,0.2,0.3,0.1,0.2,1,1,0,0,0',
'123,xclick,c,1,0.2,0.3,0.1,0.2,1,1,0,0,0'
], all_result_rows)
def test_labels_expansion(self):
hierarchy, _, label_rows = create_test_data()
expansion_generator = (
oid_hierarchical_labels_expansion.OIDHierarchicalLabelsExpansion(
hierarchy))
all_result_rows = []
for row in label_rows:
all_result_rows.extend(expansion_generator.expand_labels_from_csv(row))
self.assertItemsEqual([
'123,verification,b,0', '123,verification,c,0', '123,verification,d,0',
'123,verification,e,0', '124,verification,d,1', '124,verification,f,1',
'124,verification,c,1'
], all_result_rows)
if __name__ == '__main__':
tf.test.main()