Commit ab0aa4ef authored by liukuikun's avatar liukuikun Committed by zhouzaida
Browse files

load type (#1921)

parent f59aec8f
......@@ -120,13 +120,13 @@ class LoadAnnotations(BaseTransform):
{
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image
# in np.float32
'gt_bboxes': np.ndarray(N, 4)
# In int type.
# In np.int32 type.
'gt_bboxes_labels': np.ndarray(N, )
# In uint8 type.
'gt_seg_map': np.ndarray (H, W)
# in (x, y, v) order, float type.
# with (x, y, v) order, in np.float32 type.
'gt_keypoints': np.ndarray(N, NK, 3)
}
......@@ -142,10 +142,10 @@ class LoadAnnotations(BaseTransform):
Added Keys:
- gt_bboxes
- gt_bboxes_labels
- gt_seg_map
- gt_keypoints
- gt_bboxes (np.float32)
- gt_bboxes_labels (np.int32)
- gt_seg_map (np.uint8)
- gt_keypoints (np.float32)
Args:
with_bbox (bool): Whether to parse and load the bbox annotation.
......@@ -194,7 +194,7 @@ class LoadAnnotations(BaseTransform):
gt_bboxes = []
for instance in results['instances']:
gt_bboxes.append(instance['bbox'])
results['gt_bboxes'] = np.array(gt_bboxes)
results['gt_bboxes'] = np.array(gt_bboxes, dtype=np.float32)
def _load_labels(self, results: dict) -> None:
"""Private function to load label annotations.
......@@ -208,7 +208,8 @@ class LoadAnnotations(BaseTransform):
gt_bboxes_labels = []
for instance in results['instances']:
gt_bboxes_labels.append(instance['bbox_label'])
results['gt_bboxes_labels'] = np.array(gt_bboxes_labels)
results['gt_bboxes_labels'] = np.array(
gt_bboxes_labels, dtype=np.int32)
def _load_seg_map(self, results: dict) -> None:
"""Private function to load semantic segmentation annotations.
......@@ -236,7 +237,7 @@ class LoadAnnotations(BaseTransform):
gt_keypoints = []
for instance in results['instances']:
gt_keypoints.append(instance['keypoints'])
results['gt_keypoints'] = np.array(gt_keypoints).reshape(
results['gt_keypoints'] = np.array(gt_keypoints, np.float32).reshape(
(len(gt_keypoints), -1, 3))
def transform(self, results: dict) -> dict:
......
......@@ -74,6 +74,7 @@ class TestLoadAnnotations:
assert 'gt_bboxes' in results
assert (results['gt_bboxes'] == np.array([[0, 0, 10, 20],
[10, 10, 110, 120]])).all()
assert results['gt_bboxes'].dtype == np.float32
def test_load_labels(self):
transform = LoadAnnotations(
......@@ -85,6 +86,7 @@ class TestLoadAnnotations:
results = transform(copy.deepcopy(self.results))
assert 'gt_bboxes_labels' in results
assert (results['gt_bboxes_labels'] == np.array([1, 2])).all()
assert results['gt_bboxes_labels'].dtype == np.int32
def test_load_kps(self):
transform = LoadAnnotations(
......@@ -97,6 +99,7 @@ class TestLoadAnnotations:
assert 'gt_keypoints' in results
assert (results['gt_keypoints'] == np.array([[[1, 2, 3]],
[[4, 5, 6]]])).all()
assert results['gt_keypoints'].dtype == np.float32
def test_load_seg_map(self):
transform = LoadAnnotations(
......@@ -108,6 +111,7 @@ class TestLoadAnnotations:
results = transform(copy.deepcopy(self.results))
assert 'gt_seg_map' in results
assert results['gt_seg_map'].shape[:2] == (300, 400)
assert results['gt_seg_map'].dtype == np.uint8
def test_repr(self):
transform = LoadAnnotations(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment