"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "6d89995687e7853946175314a893a11fdc695a0c"
Commit b9c61118 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Added the get_keypoint_label_map_dict function to prepare the merged dictionary

that maps from the keypoint names to their ids.

PiperOrigin-RevId: 400009398
parent c1ff413a
...@@ -232,6 +232,39 @@ def get_label_map_dict(label_map_path_or_proto, ...@@ -232,6 +232,39 @@ def get_label_map_dict(label_map_path_or_proto,
return label_map_dict return label_map_dict
def get_keypoint_label_map_dict(label_map_path_or_proto):
"""Reads a label map and returns a dictionary of keypoint names to ids.
Note that the keypoints belong to different classes will be merged into a
single dictionary. It is expected that there is no duplicated keypoint names
or ids from different classes.
Args:
label_map_path_or_proto: path to StringIntLabelMap proto text file or the
proto itself.
Returns:
A dictionary mapping keypoint names to the keypoint id (not the object id).
Raises:
ValueError: if there are duplicated keyoint names or ids.
"""
if isinstance(label_map_path_or_proto, string_types):
label_map = load_labelmap(label_map_path_or_proto)
else:
label_map = label_map_path_or_proto
label_map_dict = {}
for item in label_map.item:
for kpts in item.keypoints:
if kpts.label in label_map_dict.keys():
raise ValueError('Duplicated keypoint label: %s' % kpts.label)
if kpts.id in label_map_dict.values():
raise ValueError('Duplicated keypoint ID: %d' % kpts.id)
label_map_dict[kpts.label] = kpts.id
return label_map_dict
def get_label_map_hierarchy_lut(label_map_path_or_proto, def get_label_map_hierarchy_lut(label_map_path_or_proto,
include_identity=False): include_identity=False):
"""Reads a label map and returns ancestors and descendants in the hierarchy. """Reads a label map and returns ancestors and descendants in the hierarchy.
......
...@@ -74,6 +74,82 @@ class LabelMapUtilTest(tf.test.TestCase): ...@@ -74,6 +74,82 @@ class LabelMapUtilTest(tf.test.TestCase):
self.assertEqual(label_map_dict['dog'], 1) self.assertEqual(label_map_dict['dog'], 1)
self.assertEqual(label_map_dict['cat'], 2) self.assertEqual(label_map_dict['cat'], 2)
def test_get_keypoint_label_map_dict(self):
label_map_string = """
item: {
id: 1
name: 'face'
display_name: 'face'
keypoints {
id: 0
label: 'left_eye'
}
keypoints {
id: 1
label: 'right_eye'
}
}
item: {
id: 2
name: '/m/01g317'
display_name: 'person'
keypoints {
id: 2
label: 'left_shoulder'
}
keypoints {
id: 3
label: 'right_shoulder'
}
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
label_map_dict = label_map_util.get_keypoint_label_map_dict(label_map_path)
self.assertEqual(label_map_dict['left_eye'], 0)
self.assertEqual(label_map_dict['right_eye'], 1)
self.assertEqual(label_map_dict['left_shoulder'], 2)
self.assertEqual(label_map_dict['right_shoulder'], 3)
def test_get_keypoint_label_map_dict_invalid(self):
label_map_string = """
item: {
id: 1
name: 'face'
display_name: 'face'
keypoints {
id: 0
label: 'left_eye'
}
keypoints {
id: 1
label: 'right_eye'
}
}
item: {
id: 2
name: '/m/01g317'
display_name: 'person'
keypoints {
id: 0
label: 'left_shoulder'
}
keypoints {
id: 1
label: 'right_shoulder'
}
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
with self.assertRaises(ValueError):
_ = label_map_util.get_keypoint_label_map_dict(
label_map_path)
def test_get_label_map_dict_from_proto(self): def test_get_label_map_dict_from_proto(self):
label_map_string = """ label_map_string = """
item { item {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment