Commit c04eea1e authored by Jonathan Huang's avatar Jonathan Huang Committed by TF Object Detection Team
Browse files

Add support in label map proto for LVIS specific fields.

PiperOrigin-RevId: 338526646
parent aca137c1
...@@ -6,6 +6,14 @@ syntax = "proto2"; ...@@ -6,6 +6,14 @@ syntax = "proto2";
package object_detection.protos; package object_detection.protos;
// LVIS frequency:
enum LVISFrequency {
UNSPECIFIED = 0;
FREQUENT = 1;
COMMON = 2;
RARE = 3;
}
message StringIntLabelMapItem { message StringIntLabelMapItem {
// String name. The most common practice is to set this to a MID or synsets // String name. The most common practice is to set this to a MID or synsets
// id. // id.
...@@ -38,6 +46,10 @@ message StringIntLabelMapItem { ...@@ -38,6 +46,10 @@ message StringIntLabelMapItem {
// current element. Value should correspond to another label id element. // current element. Value should correspond to another label id element.
repeated int32 ancestor_ids = 5; repeated int32 ancestor_ids = 5;
repeated int32 descendant_ids = 6; repeated int32 descendant_ids = 6;
// LVIS specific label map fields
optional LVISFrequency frequency = 7;
optional int32 instance_count = 8;
}; };
message StringIntLabelMap { message StringIntLabelMap {
......
...@@ -130,6 +130,18 @@ def convert_label_map_to_categories(label_map, ...@@ -130,6 +130,18 @@ def convert_label_map_to_categories(label_map,
if item.id not in list_of_ids_already_added: if item.id not in list_of_ids_already_added:
list_of_ids_already_added.append(item.id) list_of_ids_already_added.append(item.id)
category = {'id': item.id, 'name': name} category = {'id': item.id, 'name': name}
if item.HasField('frequency'):
if item.frequency == string_int_label_map_pb2.LVISFrequency.Value(
'FREQUENT'):
category['frequency'] = 'f'
elif item.frequency == string_int_label_map_pb2.LVISFrequency.Value(
'COMMON'):
category['frequency'] = 'c'
elif item.frequency == string_int_label_map_pb2.LVISFrequency.Value(
'RARE'):
category['frequency'] = 'r'
if item.HasField('instance_count'):
category['instance_count'] = item.instance_count
if item.keypoints: if item.keypoints:
keypoints = {} keypoints = {}
list_of_keypoint_ids = [] list_of_keypoint_ids = []
......
...@@ -201,7 +201,7 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -201,7 +201,7 @@ class LabelMapUtilTest(tf.test.TestCase):
name:'n00007846' name:'n00007846'
} }
""" """
text_format.Merge(label_map_string, label_map_proto) text_format.Parse(label_map_string, label_map_proto)
categories = label_map_util.convert_label_map_to_categories( categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=3) label_map_proto, max_num_classes=3)
self.assertListEqual([{ self.assertListEqual([{
...@@ -227,19 +227,61 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -227,19 +227,61 @@ class LabelMapUtilTest(tf.test.TestCase):
}] }]
self.assertListEqual(expected_categories_list, categories) self.assertListEqual(expected_categories_list, categories)
def test_convert_label_map_to_categories_lvis_frequency_and_counts(self):
label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
label_map_string = """
item {
id:1
name:'person'
frequency: FREQUENT
instance_count: 1000
}
item {
id:2
name:'dog'
frequency: COMMON
instance_count: 100
}
item {
id:3
name:'cat'
frequency: RARE
instance_count: 10
}
"""
text_format.Parse(label_map_string, label_map_proto)
categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=3)
self.assertListEqual([{
'id': 1,
'name': u'person',
'frequency': 'f',
'instance_count': 1000
}, {
'id': 2,
'name': u'dog',
'frequency': 'c',
'instance_count': 100
}, {
'id': 3,
'name': u'cat',
'frequency': 'r',
'instance_count': 10
}], categories)
def test_convert_label_map_to_categories(self): def test_convert_label_map_to_categories(self):
label_map_proto = self._generate_label_map(num_classes=4) label_map_proto = self._generate_label_map(num_classes=4)
categories = label_map_util.convert_label_map_to_categories( categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=3) label_map_proto, max_num_classes=3)
expected_categories_list = [{ expected_categories_list = [{
'name': u'1', 'name': u'1',
'id': 1 'id': 1,
}, { }, {
'name': u'2', 'name': u'2',
'id': 2 'id': 2,
}, { }, {
'name': u'3', 'name': u'3',
'id': 3 'id': 3,
}] }]
self.assertListEqual(expected_categories_list, categories) self.assertListEqual(expected_categories_list, categories)
...@@ -259,7 +301,7 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -259,7 +301,7 @@ class LabelMapUtilTest(tf.test.TestCase):
} }
""" """
label_map_proto = string_int_label_map_pb2.StringIntLabelMap() label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
text_format.Merge(label_map_str, label_map_proto) text_format.Parse(label_map_str, label_map_proto)
categories = label_map_util.convert_label_map_to_categories( categories = label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=1) label_map_proto, max_num_classes=1)
self.assertEqual('person', categories[0]['name']) self.assertEqual('person', categories[0]['name'])
...@@ -291,7 +333,7 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -291,7 +333,7 @@ class LabelMapUtilTest(tf.test.TestCase):
} }
""" """
label_map_proto = string_int_label_map_pb2.StringIntLabelMap() label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
text_format.Merge(label_map_str, label_map_proto) text_format.Parse(label_map_str, label_map_proto)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
label_map_util.convert_label_map_to_categories( label_map_util.convert_label_map_to_categories(
label_map_proto, max_num_classes=2) label_map_proto, max_num_classes=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment