Commit 9139a7b9 authored by Jonathan Huang, committed by TF Object Detection Team


Plumb LVIS specific fields (e.g. `neg_category_ids`, `not_exhaustive_category_ids`) through input pipelines.

PiperOrigin-RevId: 339614575
parent 24e41ffe
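
The new fields travel on two tf.train.Example keys, `image/neg_category_ids` and `image/not_exhaustive_category_ids` (added to the decoder below). A minimal sketch of an example carrying them; the encoded image bytes are a placeholder and the id values are illustrative:

import tensorflow.compat.v1 as tf

def _int64_list(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def _bytes(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': _bytes(b'<jpeg bytes>'),  # placeholder payload
    'image/format': _bytes(b'jpeg'),
    # LVIS: classes verified as not present in this image.
    'image/neg_category_ids': _int64_list([0, 5, 8]),
    # LVIS: classes whose instances are not exhaustively annotated.
    'image/not_exhaustive_category_ids': _int64_list([3]),
}))
serialized = example.SerializeToString()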

object_detection/core/model.py

@@ -313,7 +313,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
                           groundtruth_group_of_list=None,
                           groundtruth_area_list=None,
                           is_annotated_list=None,
-                          groundtruth_labeled_classes=None):
+                          groundtruth_labeled_classes=None,
+                          groundtruth_verified_neg_classes=None,
+                          groundtruth_not_exhaustive_classes=None):
     """Provide groundtruth tensors.
 
     Args:
@@ -371,6 +373,12 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
       groundtruth_labeled_classes: A list of 1-D tf.float32 tensors of shape
         [num_classes], containing label indices encoded as k-hot of the classes
         that are exhaustively annotated.
+      groundtruth_verified_neg_classes: A list of 1-D tf.float32 tensors of
+        shape [num_classes], containing a K-hot representation of classes
+        which were verified as not present in the image.
+      groundtruth_not_exhaustive_classes: A list of 1-D tf.float32 tensors of
+        shape [num_classes], containing a K-hot representation of classes
+        which don't have all of their instances marked exhaustively.
     """
     self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
     self._groundtruth_lists[
@@ -430,6 +438,15 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
       self._groundtruth_lists[
           fields.InputDataFields
           .groundtruth_labeled_classes] = groundtruth_labeled_classes
+    if groundtruth_verified_neg_classes:
+      self._groundtruth_lists[
+          fields.InputDataFields
+          .groundtruth_verified_neg_classes] = groundtruth_verified_neg_classes
+    if groundtruth_not_exhaustive_classes:
+      self._groundtruth_lists[
+          fields.InputDataFields
+          .groundtruth_not_exhaustive_classes] = (
+              groundtruth_not_exhaustive_classes)
 
   @abc.abstractmethod
   def regularization_losses(self):
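
A minimal sketch of the extended call; `model` here stands for any constructed DetectionModel subclass and the tensor values are placeholders, with one [num_classes] K-hot vector per image as the docstring specifies:

import tensorflow.compat.v1 as tf

num_classes = 10
model.provide_groundtruth(  # model: any DetectionModel instance (placeholder)
    groundtruth_boxes_list=[tf.zeros([3, 4], tf.float32)],
    groundtruth_classes_list=[tf.one_hot([0, 2, 5], num_classes)],
    # K-hot vector of classes verified absent from the image.
    groundtruth_verified_neg_classes=[
        tf.constant([0., 1., 0., 0., 0., 0., 0., 0., 0., 1.])],
    # K-hot vector of classes whose instances are not exhaustively marked.
    groundtruth_not_exhaustive_classes=[
        tf.constant([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])])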

object_detection/data_decoders/tf_example_decoder.py

@@ -203,6 +203,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
             tf.VarLenFeature(tf.string),
         'image/class/label':
             tf.VarLenFeature(tf.int64),
+        'image/neg_category_ids':
+            tf.VarLenFeature(tf.int64),
+        'image/not_exhaustive_category_ids':
+            tf.VarLenFeature(tf.int64),
         'image/class/confidence':
             tf.VarLenFeature(tf.float32),
         # Object boxes and classes.
@@ -264,6 +268,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
         # Image-level labels.
         fields.InputDataFields.groundtruth_image_confidences: (
             slim_example_decoder.Tensor('image/class/confidence')),
+        fields.InputDataFields.groundtruth_verified_neg_classes: (
+            slim_example_decoder.Tensor('image/neg_category_ids')),
+        fields.InputDataFields.groundtruth_not_exhaustive_classes: (
+            slim_example_decoder.Tensor('image/not_exhaustive_category_ids')),
         # Object boxes and classes.
         fields.InputDataFields.groundtruth_boxes: (
             slim_example_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
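
Note that the decoder surfaces these keys as variable-length lists of integer category ids, while DetectionModel.provide_groundtruth documents the corresponding arguments as [num_classes] K-hot float vectors, so the ids must be densified somewhere in the input pipeline. A hedged sketch of such a conversion (hypothetical helper, not part of this diff; id indexing is dataset-dependent):

import tensorflow.compat.v1 as tf

def ids_to_k_hot(category_ids, num_classes, label_id_offset=0):
  # category_ids: 1-D int64 tensor of class ids; adjust label_id_offset if
  # the dataset uses 1-indexed ids.
  indices = tf.cast(category_ids, tf.int32) - label_id_offset
  counts = tf.scatter_nd(indices[:, tf.newaxis],
                         tf.ones_like(indices),
                         shape=[num_classes])
  # scatter_nd sums duplicate indices, so threshold back to {0, 1}.
  return tf.cast(counts > 0, tf.float32)

# ids_to_k_hot(tf.constant([0, 5, 8], tf.int64), num_classes=10)
# -> [1., 0., 0., 0., 0., 1., 0., 0., 1., 0.]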

object_detection/data_decoders/tf_example_decoder_test.py

@@ -841,6 +841,61 @@ class TfExampleDecoderTest(test_case.TestCase):
     self.assertAllEqual(object_area,
                         tensor_dict[fields.InputDataFields.groundtruth_area])
 
+  def testDecodeVerifiedNegClasses(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    neg_category_ids = [0, 5, 8]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/neg_category_ids':
+                      dataset_util.int64_list_feature(neg_category_ids),
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder()
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+      return output
+
+    tensor_dict = self.execute_cpu(graph_fn, [])
+    self.assertAllEqual(
+        neg_category_ids,
+        tensor_dict[fields.InputDataFields.groundtruth_verified_neg_classes])
+
+  def testDecodeNotExhaustiveClasses(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    not_exhaustive_category_ids = [0, 5, 8]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/not_exhaustive_category_ids':
+                      dataset_util.int64_list_feature(
+                          not_exhaustive_category_ids),
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder()
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+      return output
+
+    tensor_dict = self.execute_cpu(graph_fn, [])
+    self.assertAllEqual(
+        not_exhaustive_category_ids,
+        tensor_dict[fields.InputDataFields.groundtruth_not_exhaustive_classes])
+
   def testDecodeObjectIsCrowd(self):
     image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
     encoded_jpeg, _ = self._create_encoded_and_decoded_data(

object_detection/eval_util.py

@@ -33,6 +33,7 @@ from object_detection.core import box_list_ops
 from object_detection.core import keypoint_ops
 from object_detection.core import standard_fields as fields
 from object_detection.metrics import coco_evaluation
+from object_detection.metrics import lvis_evaluation
 from object_detection.protos import eval_pb2
 from object_detection.utils import label_map_util
 from object_detection.utils import object_detection_evaluation
@@ -54,6 +55,8 @@ EVAL_METRICS_CLASS_DICT = {
         coco_evaluation.CocoMaskEvaluator,
     'coco_panoptic_metrics':
         coco_evaluation.CocoPanopticSegmentationEvaluator,
+    'lvis_mask_metrics':
+        lvis_evaluation.LVISMaskEvaluator,
     'oid_challenge_detection_metrics':
         object_detection_evaluation.OpenImagesDetectionChallengeEvaluator,
     'oid_challenge_segmentation_metrics':
@@ -548,10 +551,36 @@ def _scale_box_to_absolute(args):
       box_list.BoxList(boxes), image_shape[0], image_shape[1]).get()
 
 
-def _resize_detection_masks(args):
-  detection_boxes, detection_masks, image_shape = args
+def _resize_detection_masks(arg_tuple):
+  """Resizes detection masks.
+
+  Args:
+    arg_tuple: A (detection_boxes, detection_masks, image_shape, pad_shape)
+      tuple where
+      detection_boxes is a tf.float32 tensor of size [num_masks, 4] containing
+        the box corners. Row i contains [ymin, xmin, ymax, xmax] of the box
+        corresponding to mask i. Note that the box corners are in
+        normalized coordinates.
+      detection_masks is a tensor of size
+        [num_masks, mask_height, mask_width].
+      image_shape is a tensor of shape [2].
+      pad_shape is a tensor of shape [2]; it is assumed to be greater than or
+        equal to image_shape along both dimensions and represents the shape
+        the masks are padded to.
+
+  Returns:
+    A tf.uint8 tensor of size [num_masks, pad_shape[0], pad_shape[1]]
+    containing the reframed, padded and binarized masks.
+  """
+  detection_boxes, detection_masks, image_shape, pad_shape = arg_tuple
+
   detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
       detection_masks, detection_boxes, image_shape[0], image_shape[1])
+
+  paddings = tf.concat(
+      [tf.zeros([3, 1], dtype=tf.int32),
+       tf.expand_dims(
+           tf.concat([tf.zeros([1], dtype=tf.int32),
+                      pad_shape - image_shape], axis=0),
+           1)], axis=1)
+  detection_masks_reframed = tf.pad(detection_masks_reframed, paddings)
+
   # If the masks are currently float, binarize them. Otherwise keep them as
   # integers, since they have already been thresholded.
   if detection_masks_reframed.dtype == tf.float32:
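
For intuition, the paddings construction above builds a [3, 2] tensor that leaves the num_masks axis untouched and zero-pads only the bottom and right of the two spatial axes. A toy check with assumed shapes:

import tensorflow.compat.v1 as tf

image_shape = tf.constant([480, 640], dtype=tf.int32)  # assumed values
pad_shape = tf.constant([600, 800], dtype=tf.int32)
paddings = tf.concat(
    [tf.zeros([3, 1], dtype=tf.int32),
     tf.expand_dims(
         tf.concat([tf.zeros([1], dtype=tf.int32),
                    pad_shape - image_shape], axis=0),
         1)], axis=1)
# paddings == [[0, 0], [0, 120], [0, 160]], so tf.pad grows a
# [num_masks, 480, 640] mask tensor to [num_masks, 600, 800].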
@@ -569,7 +598,7 @@ def resize_detection_masks(detection_boxes, detection_masks,
   Args:
     detection_boxes: A [batch_size, num_instances, 4] float tensor containing
       bounding boxes.
-    detection_masks: A [batch_suze, num_instances, height, width] float tensor
+    detection_masks: A [batch_size, num_instances, height, width] float tensor
       containing binary instance masks per box.
     original_image_spatial_shapes: a [batch_size, 3] shaped int tensor
       holding the spatial dimensions of each image in the batch.
@@ -577,15 +606,26 @@ def resize_detection_masks(detection_boxes, detection_masks,
     masks: Masks resized to the spatial extents given by
       (original_image_spatial_shapes[0, 0], original_image_spatial_shapes[0, 1])
   """
+  # Modify original image spatial shapes to be the max along each dim; in the
+  # evaluator, the original_image_spatial_shape field should be available in
+  # the eval dict.
+  max_spatial_shape = tf.reduce_max(
+      original_image_spatial_shapes, axis=0, keep_dims=True)
+  tiled_max_spatial_shape = tf.tile(
+      max_spatial_shape,
+      multiples=[tf.shape(original_image_spatial_shapes)[0], 1])
+
   return shape_utils.static_or_dynamic_map_fn(
       _resize_detection_masks,
-      elems=[detection_boxes, detection_masks, original_image_spatial_shapes],
+      elems=[detection_boxes,
+             detection_masks,
+             original_image_spatial_shapes,
+             tiled_max_spatial_shape],
       dtype=tf.uint8)
 
 
 def _resize_groundtruth_masks(args):
-  """Resizes groundgtruth masks to the original image size."""
-  mask, true_image_shape, original_image_shape = args
+  """Resizes groundtruth masks to the original image size."""
+  mask, true_image_shape, original_image_shape, pad_shape = args
   true_height = true_image_shape[0]
   true_width = true_image_shape[1]
   mask = mask[:, :true_height, :true_width]
@@ -595,7 +635,15 @@ def _resize_groundtruth_masks(args):
       original_image_shape,
       method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
       align_corners=True)
-  return tf.cast(tf.squeeze(mask, 3), tf.uint8)
+
+  paddings = tf.concat(
+      [tf.zeros([3, 1], dtype=tf.int32),
+       tf.expand_dims(
+           tf.concat([tf.zeros([1], dtype=tf.int32),
+                      pad_shape - original_image_shape], axis=0),
+           1)], axis=1)
+  mask = tf.pad(tf.squeeze(mask, 3), paddings)
+  return tf.cast(mask, tf.uint8)
 
 
 def _resize_surface_coordinate_masks(args):
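
The reason both resize helpers now take a pad_shape: static_or_dynamic_map_fn stacks its per-image outputs, which therefore must all share one spatial shape. Tiling the per-batch maximum, as the callers do, gives every image the same padding target. A small sketch with assumed sizes:

import tensorflow.compat.v1 as tf

# Two images of different original sizes in one eval batch.
original_image_spatial_shapes = tf.constant([[480, 640],
                                             [600, 512]], dtype=tf.int32)
max_spatial_shape = tf.reduce_max(
    original_image_spatial_shapes, axis=0, keep_dims=True)  # [[600, 640]]
tiled_max_spatial_shape = tf.tile(
    max_spatial_shape,
    multiples=[tf.shape(original_image_spatial_shapes)[0], 1])
# -> [[600, 640], [600, 640]]: every image's masks pad to the batch-wide
# maximum, so map_fn can stack them even though the images differ in size.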
@@ -932,10 +980,17 @@ def result_dict_for_batched_example(images,
 
   if input_data_fields.groundtruth_instance_masks in groundtruth:
     masks = groundtruth[input_data_fields.groundtruth_instance_masks]
+    max_spatial_shape = tf.reduce_max(
+        original_image_spatial_shapes, axis=0, keep_dims=True)
+    tiled_max_spatial_shape = tf.tile(
+        max_spatial_shape,
+        multiples=[tf.shape(original_image_spatial_shapes)[0], 1])
     groundtruth[input_data_fields.groundtruth_instance_masks] = (
         shape_utils.static_or_dynamic_map_fn(
             _resize_groundtruth_masks,
-            elems=[masks, true_image_shapes, original_image_spatial_shapes],
+            elems=[masks, true_image_shapes,
+                   original_image_spatial_shapes,
+                   tiled_max_spatial_shape],
             dtype=tf.uint8))
 
   output_dict.update(groundtruth)
@@ -1116,7 +1171,8 @@ def evaluator_options_from_eval_config(eval_config):
   eval_metric_fn_keys = eval_config.metrics_set
   evaluator_options = {}
   for eval_metric_fn_key in eval_metric_fn_keys:
-    if eval_metric_fn_key in ('coco_detection_metrics', 'coco_mask_metrics'):
+    if eval_metric_fn_key in (
+        'coco_detection_metrics', 'coco_mask_metrics', 'lvis_mask_metrics'):
       evaluator_options[eval_metric_fn_key] = {
           'include_metrics_per_category': (
               eval_config.include_metrics_per_category)
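
With the registration above, the new evaluator is reachable through the usual metrics_set path. A minimal sketch (field names as already used in this file; any further LVIS-specific options are not shown in this diff):

from object_detection import eval_util
from object_detection.protos import eval_pb2

eval_config = eval_pb2.EvalConfig()
eval_config.metrics_set.append('lvis_mask_metrics')
eval_config.include_metrics_per_category = True

# 'lvis_mask_metrics' now gets the same per-category option plumbing as the
# COCO metrics.
options = eval_util.evaluator_options_from_eval_config(eval_config)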
(The diff for one additional file is collapsed and not shown here.)

object_detection/model_lib.py

@@ -108,6 +108,12 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
         group_of annotations (if provided in groundtruth).
       'groundtruth_labeled_classes': [batch_size, num_classes] int64
         tensor of 1-indexed classes.
+      'groundtruth_verified_neg_classes': [batch_size, num_classes] float32
+        K-hot representation of 1-indexed classes which were verified as not
+        present in the image.
+      'groundtruth_not_exhaustive_classes': [batch_size, num_classes] K-hot
+        representation of 1-indexed classes which don't have all of their
+        instances marked exhaustively.
     class_agnostic: Boolean indicating whether detections are class agnostic.
   """
   input_data_fields = fields.InputDataFields()
@@ -129,6 +135,7 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
       input_data_fields.groundtruth_boxes: groundtruth_boxes,
       input_data_fields.groundtruth_classes: groundtruth_classes
   }
+
   if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
     groundtruth[input_data_fields.groundtruth_instance_masks] = tf.stack(
         detection_model.groundtruth_lists(fields.BoxListFields.masks))
@@ -156,23 +163,17 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
         detection_model.groundtruth_lists(fields.BoxListFields.group_of))
 
   if detection_model.groundtruth_has_field(
-      fields.InputDataFields.groundtruth_labeled_classes):
-    labeled_classes_list = detection_model.groundtruth_lists(
-        fields.InputDataFields.groundtruth_labeled_classes)
-    labeled_classes = [
-        tf.where(x)[:, 0] + label_id_offset for x in labeled_classes_list
-    ]
-    if len(labeled_classes) > 1:
-      num_classes = labeled_classes_list[0].shape[0]
-      padded_labeled_classes = []
-      for x in labeled_classes:
-        padding = num_classes - tf.shape(x)[0]
-        padded_labeled_classes.append(tf.pad(x, [[0, padding]]))
-      groundtruth[input_data_fields.groundtruth_labeled_classes] = tf.stack(
-          padded_labeled_classes)
-    else:
-      groundtruth[input_data_fields.groundtruth_labeled_classes] = tf.stack(
-          labeled_classes)
+      input_data_fields.groundtruth_verified_neg_classes):
+    groundtruth[input_data_fields.groundtruth_verified_neg_classes] = tf.stack(
+        detection_model.groundtruth_lists(
+            input_data_fields.groundtruth_verified_neg_classes))
+
+  if detection_model.groundtruth_has_field(
+      input_data_fields.groundtruth_not_exhaustive_classes):
+    groundtruth[
+        input_data_fields.groundtruth_not_exhaustive_classes] = tf.stack(
+            detection_model.groundtruth_lists(
+                input_data_fields.groundtruth_not_exhaustive_classes))
 
   if detection_model.groundtruth_has_field(
       fields.BoxListFields.densepose_num_points):
@@ -194,6 +195,25 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
     groundtruth[input_data_fields.groundtruth_track_ids] = tf.stack(
         detection_model.groundtruth_lists(fields.BoxListFields.track_ids))
 
+  if detection_model.groundtruth_has_field(
+      input_data_fields.groundtruth_labeled_classes):
+    labeled_classes_list = detection_model.groundtruth_lists(
+        input_data_fields.groundtruth_labeled_classes)
+    labeled_classes = [
+        tf.where(x)[:, 0] + label_id_offset for x in labeled_classes_list
+    ]
+    if len(labeled_classes) > 1:
+      num_classes = labeled_classes_list[0].shape[0]
+      padded_labeled_classes = []
+      for x in labeled_classes:
+        padding = num_classes - tf.shape(x)[0]
+        padded_labeled_classes.append(tf.pad(x, [[0, padding]]))
+      groundtruth[input_data_fields.groundtruth_labeled_classes] = tf.stack(
+          padded_labeled_classes)
+    else:
+      groundtruth[input_data_fields.groundtruth_labeled_classes] = tf.stack(
+          labeled_classes)
+
   groundtruth[input_data_fields.num_groundtruth_boxes] = (
       tf.tile([max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
   return groundtruth
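
The relocated labeled-classes block converts each image's K-hot vector into a padded list of 1-indexed class ids so the per-image results can be stacked. A worked toy example (hypothetical values; an explicit boolean mask is used here since tf.where expects one):

import tensorflow.compat.v1 as tf

num_classes = 4
k_hot = tf.constant([0., 1., 0., 1.])  # classes 2 and 4, 1-indexed
ids = tf.where(k_hot > 0)[:, 0] + 1    # -> [2, 4]
padding = num_classes - tf.shape(ids)[0]
padded = tf.pad(ids, [[0, padding]])   # -> [2, 4, 0, 0]
# Stacking one such vector per image yields the [batch_size, num_classes]
# groundtruth_labeled_classes tensor described in the docstring above.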
@@ -339,6 +359,14 @@ def provide_groundtruth(model, labels):
   if fields.InputDataFields.groundtruth_labeled_classes in labels:
     gt_labeled_classes = labels[
         fields.InputDataFields.groundtruth_labeled_classes]
+  gt_verified_neg_classes = None
+  if fields.InputDataFields.groundtruth_verified_neg_classes in labels:
+    gt_verified_neg_classes = labels[
+        fields.InputDataFields.groundtruth_verified_neg_classes]
+  gt_not_exhaustive_classes = None
+  if fields.InputDataFields.groundtruth_not_exhaustive_classes in labels:
+    gt_not_exhaustive_classes = labels[
+        fields.InputDataFields.groundtruth_not_exhaustive_classes]
 
   model.provide_groundtruth(
       groundtruth_boxes_list=gt_boxes_list,
       groundtruth_classes_list=gt_classes_list,
@@ -354,7 +382,9 @@ def provide_groundtruth(model, labels):
       groundtruth_is_crowd_list=gt_is_crowd_list,
       groundtruth_group_of_list=gt_group_of_list,
       groundtruth_area_list=gt_area_list,
-      groundtruth_track_ids_list=gt_track_ids_list)
+      groundtruth_track_ids_list=gt_track_ids_list,
+      groundtruth_verified_neg_classes=gt_verified_neg_classes,
+      groundtruth_not_exhaustive_classes=gt_not_exhaustive_classes)
 
 
 def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,

object_detection/model_lib_v2.py

@@ -703,6 +703,7 @@ def eager_eval_loop(
   evaluator_options = eval_util.evaluator_options_from_eval_config(
       eval_config)
+  batch_size = eval_config.batch_size
 
   class_agnostic_category_index = (
       label_map_util.create_class_agnostic_category_index())
@@ -731,7 +732,9 @@ def eager_eval_loop(
     # must be unpadded.
     boxes_shape = (
         labels[fields.InputDataFields.groundtruth_boxes].get_shape().as_list())
-    unpad_groundtruth_tensors = boxes_shape[1] is not None and not use_tpu
+    unpad_groundtruth_tensors = (boxes_shape[1] is not None
+                                 and not use_tpu
+                                 and batch_size == 1)
     labels = model_lib.unstack_batch(
         labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
@@ -799,7 +802,8 @@ def eager_eval_loop(
       tf.logging.info('Finished eval step %d', i)
 
     use_original_images = fields.InputDataFields.original_image in features
-    if use_original_images and i < eval_config.num_visualizations:
+    if (use_original_images and i < eval_config.num_visualizations
+        and batch_size == 1):
       sbys_image_list = vutils.draw_side_by_side_evaluation_image(
           eval_dict,
           category_index=category_index,
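
Both eager_eval_loop changes gate single-image logic on the eval batch size: unpadding slices groundtruth down to each image's real box count, and the side-by-side visualization draws one example, so neither is meaningful on a padded multi-image batch. A hedged sketch of the effect (values assumed):

# With eval_config.batch_size = 8, say:
#   unpad_groundtruth_tensors -> False (groundtruth tensors stay padded)
#   the num_visualizations branch is skipped entirely
# Only batch_size == 1 restores the previous single-image behavior.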