Commit 33cca6c3 authored by Jonathan Huang's avatar Jonathan Huang Committed by TF Object Detection Team
Browse files

Fixes two bugs when handling verified_neg_classes and not_exhaustive_classes...

Fixes two bugs when handling verified_neg_classes and not_exhaustive_classes fields for LVIS evaluation:
(1) Before, if one of these fields was empty, the input pipeline would default to an all-ones representation; this CL turns out off this behavior for the LVIS-specific classes.
(2) Labels are now 1-indexed coming out of the _prepare_groundtruth_for_eval function in model_lib (as they should be).

PiperOrigin-RevId: 340357845
parent 72a31e9e
......@@ -68,8 +68,24 @@ def _multiclass_scores_or_one_hot_labels(multiclass_scores,
return tf.cond(tf.size(multiclass_scores) > 0, true_fn, false_fn)
def _convert_labeled_classes_to_k_hot(groundtruth_labeled_classes, num_classes):
"""Returns k-hot encoding of the labeled classes."""
def _convert_labeled_classes_to_k_hot(groundtruth_labeled_classes, num_classes,
map_empty_to_ones=False):
"""Returns k-hot encoding of the labeled classes.
If map_empty_to_ones is enabled and the input labeled_classes is empty,
this function assumes all classes are exhaustively labeled, thus returning
an all-one encoding.
Args:
groundtruth_labeled_classes: a Tensor holding a sparse representation of
labeled classes.
num_classes: an integer representing the number of classes
map_empty_to_ones: boolean (default: False). Set this to be True to default
to an all-ones result if given an empty `groundtruth_labeled_classes`.
Returns:
A k-hot (and 0-indexed) tensor representation of
`groundtruth_labeled_classes`.
"""
# If the input labeled_classes is empty, it assumes all classes are
# exhaustively labeled, thus returning an all-one encoding.
......@@ -82,7 +98,9 @@ def _convert_labeled_classes_to_k_hot(groundtruth_labeled_classes, num_classes):
def false_fn():
return tf.ones(num_classes, dtype=tf.float32)
if map_empty_to_ones:
return tf.cond(tf.size(groundtruth_labeled_classes) > 0, true_fn, false_fn)
return true_fn()
def _remove_unrecognized_classes(class_ids, unrecognized_label):
......@@ -209,15 +227,16 @@ def transform_input_data(tensor_dict,
raise KeyError('groundtruth_labeled_classes and groundtruth_image_classes'
'are provided by the decoder, but only one should be set.')
for field in [labeled_classes_field,
image_classes_field,
verified_neg_classes_field,
not_exhaustive_field]:
for field, map_empty_to_ones in [
(labeled_classes_field, True),
(image_classes_field, True),
(verified_neg_classes_field, False),
(not_exhaustive_field, False)]:
if field in out_tensor_dict:
out_tensor_dict[field] = _remove_unrecognized_classes(
out_tensor_dict[field], unrecognized_label=-1)
out_tensor_dict[field] = _convert_labeled_classes_to_k_hot(
out_tensor_dict[field], num_classes)
out_tensor_dict[field], num_classes, map_empty_to_ones)
if input_fields.multiclass_scores in out_tensor_dict:
out_tensor_dict[
......
......@@ -162,18 +162,21 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
groundtruth[input_data_fields.groundtruth_group_of] = tf.stack(
detection_model.groundtruth_lists(fields.BoxListFields.group_of))
label_id_offset_paddings = tf.constant([[0, 0], [1, 0]])
if detection_model.groundtruth_has_field(
input_data_fields.groundtruth_verified_neg_classes):
groundtruth[input_data_fields.groundtruth_verified_neg_classes] = tf.stack(
detection_model.groundtruth_lists(
input_data_fields.groundtruth_verified_neg_classes))
groundtruth[input_data_fields.groundtruth_verified_neg_classes] = tf.pad(
tf.stack(detection_model.groundtruth_lists(
input_data_fields.groundtruth_verified_neg_classes)),
label_id_offset_paddings)
if detection_model.groundtruth_has_field(
input_data_fields.groundtruth_not_exhaustive_classes):
groundtruth[
input_data_fields.groundtruth_not_exhaustive_classes] = tf.stack(
detection_model.groundtruth_lists(
input_data_fields.groundtruth_not_exhaustive_classes))
input_data_fields.groundtruth_not_exhaustive_classes] = tf.pad(
tf.stack(detection_model.groundtruth_lists(
input_data_fields.groundtruth_not_exhaustive_classes)),
label_id_offset_paddings)
if detection_model.groundtruth_has_field(
fields.BoxListFields.densepose_num_points):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment