Commit 571880ce authored by Vivek Rathod's avatar Vivek Rathod Committed by TF Object Detection Team
Browse files

Filter our unrecognized `image_classes_field` entries (i.e -1s)

PiperOrigin-RevId: 336321500
parent 17f2f812
...@@ -88,7 +88,8 @@ def _convert_labeled_classes_to_k_hot(groundtruth_labeled_classes, num_classes): ...@@ -88,7 +88,8 @@ def _convert_labeled_classes_to_k_hot(groundtruth_labeled_classes, num_classes):
def _remove_unrecognized_classes(class_ids, unrecognized_label): def _remove_unrecognized_classes(class_ids, unrecognized_label):
"""Returns class ids with unrecognized classes filtered out.""" """Returns class ids with unrecognized classes filtered out."""
recognized_indices = tf.where(tf.greater(class_ids, unrecognized_label)) recognized_indices = tf.squeeze(
tf.where(tf.greater(class_ids, unrecognized_label)), -1)
return tf.gather(class_ids, recognized_indices) return tf.gather(class_ids, recognized_indices)
...@@ -213,6 +214,8 @@ def transform_input_data(tensor_dict, ...@@ -213,6 +214,8 @@ def transform_input_data(tensor_dict,
out_tensor_dict[labeled_classes_field], num_classes) out_tensor_dict[labeled_classes_field], num_classes)
if image_classes_field in out_tensor_dict: if image_classes_field in out_tensor_dict:
out_tensor_dict[image_classes_field] = _remove_unrecognized_classes(
out_tensor_dict[image_classes_field], unrecognized_label=-1)
out_tensor_dict[labeled_classes_field] = _convert_labeled_classes_to_k_hot( out_tensor_dict[labeled_classes_field] = _convert_labeled_classes_to_k_hot(
out_tensor_dict[image_classes_field], num_classes) out_tensor_dict[image_classes_field], num_classes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment