Commit ab0aa4ef authored by liukuikun's avatar liukuikun Committed by zhouzaida
Browse files

load type (#1921)

parent f59aec8f
...@@ -120,13 +120,13 @@ class LoadAnnotations(BaseTransform): ...@@ -120,13 +120,13 @@ class LoadAnnotations(BaseTransform):
{ {
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes # In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in an image # in np.float32
'gt_bboxes': np.ndarray(N, 4) 'gt_bboxes': np.ndarray(N, 4)
# In int type. # In np.int32 type.
'gt_bboxes_labels': np.ndarray(N, ) 'gt_bboxes_labels': np.ndarray(N, )
# In uint8 type. # In uint8 type.
'gt_seg_map': np.ndarray (H, W) 'gt_seg_map': np.ndarray (H, W)
# in (x, y, v) order, float type. # with (x, y, v) order, in np.float32 type.
'gt_keypoints': np.ndarray(N, NK, 3) 'gt_keypoints': np.ndarray(N, NK, 3)
} }
...@@ -142,10 +142,10 @@ class LoadAnnotations(BaseTransform): ...@@ -142,10 +142,10 @@ class LoadAnnotations(BaseTransform):
Added Keys: Added Keys:
- gt_bboxes - gt_bboxes (np.float32)
- gt_bboxes_labels - gt_bboxes_labels (np.int32)
- gt_seg_map - gt_seg_map (np.uint8)
- gt_keypoints - gt_keypoints (np.float32)
Args: Args:
with_bbox (bool): Whether to parse and load the bbox annotation. with_bbox (bool): Whether to parse and load the bbox annotation.
...@@ -194,7 +194,7 @@ class LoadAnnotations(BaseTransform): ...@@ -194,7 +194,7 @@ class LoadAnnotations(BaseTransform):
gt_bboxes = [] gt_bboxes = []
for instance in results['instances']: for instance in results['instances']:
gt_bboxes.append(instance['bbox']) gt_bboxes.append(instance['bbox'])
results['gt_bboxes'] = np.array(gt_bboxes) results['gt_bboxes'] = np.array(gt_bboxes, dtype=np.float32)
def _load_labels(self, results: dict) -> None: def _load_labels(self, results: dict) -> None:
"""Private function to load label annotations. """Private function to load label annotations.
...@@ -208,7 +208,8 @@ class LoadAnnotations(BaseTransform): ...@@ -208,7 +208,8 @@ class LoadAnnotations(BaseTransform):
gt_bboxes_labels = [] gt_bboxes_labels = []
for instance in results['instances']: for instance in results['instances']:
gt_bboxes_labels.append(instance['bbox_label']) gt_bboxes_labels.append(instance['bbox_label'])
results['gt_bboxes_labels'] = np.array(gt_bboxes_labels) results['gt_bboxes_labels'] = np.array(
gt_bboxes_labels, dtype=np.int32)
def _load_seg_map(self, results: dict) -> None: def _load_seg_map(self, results: dict) -> None:
"""Private function to load semantic segmentation annotations. """Private function to load semantic segmentation annotations.
...@@ -236,7 +237,7 @@ class LoadAnnotations(BaseTransform): ...@@ -236,7 +237,7 @@ class LoadAnnotations(BaseTransform):
gt_keypoints = [] gt_keypoints = []
for instance in results['instances']: for instance in results['instances']:
gt_keypoints.append(instance['keypoints']) gt_keypoints.append(instance['keypoints'])
results['gt_keypoints'] = np.array(gt_keypoints).reshape( results['gt_keypoints'] = np.array(gt_keypoints, np.float32).reshape(
(len(gt_keypoints), -1, 3)) (len(gt_keypoints), -1, 3))
def transform(self, results: dict) -> dict: def transform(self, results: dict) -> dict:
......
...@@ -74,6 +74,7 @@ class TestLoadAnnotations: ...@@ -74,6 +74,7 @@ class TestLoadAnnotations:
assert 'gt_bboxes' in results assert 'gt_bboxes' in results
assert (results['gt_bboxes'] == np.array([[0, 0, 10, 20], assert (results['gt_bboxes'] == np.array([[0, 0, 10, 20],
[10, 10, 110, 120]])).all() [10, 10, 110, 120]])).all()
assert results['gt_bboxes'].dtype == np.float32
def test_load_labels(self): def test_load_labels(self):
transform = LoadAnnotations( transform = LoadAnnotations(
...@@ -85,6 +86,7 @@ class TestLoadAnnotations: ...@@ -85,6 +86,7 @@ class TestLoadAnnotations:
results = transform(copy.deepcopy(self.results)) results = transform(copy.deepcopy(self.results))
assert 'gt_bboxes_labels' in results assert 'gt_bboxes_labels' in results
assert (results['gt_bboxes_labels'] == np.array([1, 2])).all() assert (results['gt_bboxes_labels'] == np.array([1, 2])).all()
assert results['gt_bboxes_labels'].dtype == np.int32
def test_load_kps(self): def test_load_kps(self):
transform = LoadAnnotations( transform = LoadAnnotations(
...@@ -97,6 +99,7 @@ class TestLoadAnnotations: ...@@ -97,6 +99,7 @@ class TestLoadAnnotations:
assert 'gt_keypoints' in results assert 'gt_keypoints' in results
assert (results['gt_keypoints'] == np.array([[[1, 2, 3]], assert (results['gt_keypoints'] == np.array([[[1, 2, 3]],
[[4, 5, 6]]])).all() [[4, 5, 6]]])).all()
assert results['gt_keypoints'].dtype == np.float32
def test_load_seg_map(self): def test_load_seg_map(self):
transform = LoadAnnotations( transform = LoadAnnotations(
...@@ -108,6 +111,7 @@ class TestLoadAnnotations: ...@@ -108,6 +111,7 @@ class TestLoadAnnotations:
results = transform(copy.deepcopy(self.results)) results = transform(copy.deepcopy(self.results))
assert 'gt_seg_map' in results assert 'gt_seg_map' in results
assert results['gt_seg_map'].shape[:2] == (300, 400) assert results['gt_seg_map'].shape[:2] == (300, 400)
assert results['gt_seg_map'].dtype == np.uint8
def test_repr(self): def test_repr(self):
transform = LoadAnnotations( transform = LoadAnnotations(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment