Commit 88f3cc3f authored by liukuikun's avatar liukuikun Committed by zhouzaida
Browse files

replace height, width with img_shape

parent 772069f2
...@@ -17,10 +17,8 @@ class LoadImageFromFile(BaseTransform): ...@@ -17,10 +17,8 @@ class LoadImageFromFile(BaseTransform):
Modified Keys: Modified Keys:
- img - img
- width - img_shape
- height - ori_shape
- ori_width
- ori_height
Args: Args:
to_float32 (bool): Whether to convert the loaded image to a float32 to_float32 (bool): Whether to convert the loaded image to a float32
...@@ -68,11 +66,8 @@ class LoadImageFromFile(BaseTransform): ...@@ -68,11 +66,8 @@ class LoadImageFromFile(BaseTransform):
img = img.astype(np.float32) img = img.astype(np.float32)
results['img'] = img results['img'] = img
height, width = img.shape[:2] results['img_shape'] = img.shape[:2]
results['height'] = height results['ori_shape'] = img.shape[:2]
results['width'] = width
results['ori_height'] = height
results['ori_width'] = width
return results return results
def __repr__(self): def __repr__(self):
...@@ -123,7 +118,7 @@ class LoadAnnotations(BaseTransform): ...@@ -123,7 +118,7 @@ class LoadAnnotations(BaseTransform):
# In (x1, y1, x2, y2) order, float type. N is the number of bboxes # In (x1, y1, x2, y2) order, float type. N is the number of bboxes
# in np.float32 # in np.float32
'gt_bboxes': np.ndarray(N, 4) 'gt_bboxes': np.ndarray(N, 4)
# In np.int32 type. # In np.int64 type.
'gt_bboxes_labels': np.ndarray(N, ) 'gt_bboxes_labels': np.ndarray(N, )
# In uint8 type. # In uint8 type.
'gt_seg_map': np.ndarray (H, W) 'gt_seg_map': np.ndarray (H, W)
...@@ -144,7 +139,7 @@ class LoadAnnotations(BaseTransform): ...@@ -144,7 +139,7 @@ class LoadAnnotations(BaseTransform):
Added Keys: Added Keys:
- gt_bboxes (np.float32) - gt_bboxes (np.float32)
- gt_bboxes_labels (np.int32) - gt_bboxes_labels (np.int64)
- gt_seg_map (np.uint8) - gt_seg_map (np.uint8)
- gt_keypoints (np.float32) - gt_keypoints (np.float32)
...@@ -211,7 +206,7 @@ class LoadAnnotations(BaseTransform): ...@@ -211,7 +206,7 @@ class LoadAnnotations(BaseTransform):
for instance in results['instances']: for instance in results['instances']:
gt_bboxes_labels.append(instance['bbox_label']) gt_bboxes_labels.append(instance['bbox_label'])
results['gt_bboxes_labels'] = np.array( results['gt_bboxes_labels'] = np.array(
gt_bboxes_labels, dtype=np.int32) gt_bboxes_labels, dtype=np.int64)
def _load_seg_map(self, results: dict) -> None: def _load_seg_map(self, results: dict) -> None:
"""Private function to load semantic segmentation annotations. """Private function to load semantic segmentation annotations.
......
...@@ -286,8 +286,7 @@ class Pad(BaseTransform): ...@@ -286,8 +286,7 @@ class Pad(BaseTransform):
- img - img
- gt_seg_map - gt_seg_map
- height - img_shape
- width
Added Keys: Added Keys:
...@@ -382,8 +381,7 @@ class Pad(BaseTransform): ...@@ -382,8 +381,7 @@ class Pad(BaseTransform):
results['pad_shape'] = padded_img.shape results['pad_shape'] = padded_img.shape
results['pad_fixed_size'] = self.size results['pad_fixed_size'] = self.size
results['pad_size_divisor'] = self.size_divisor results['pad_size_divisor'] = self.size_divisor
results['height'] = padded_img.shape[0] results['img_shape'] = padded_img.shape[:2]
results['width'] = padded_img.shape[1]
def _pad_seg(self, results: dict) -> None: def _pad_seg(self, results: dict) -> None:
"""Pad semantic segmentation map according to """Pad semantic segmentation map according to
......
...@@ -18,10 +18,8 @@ class TestLoadImageFromFile: ...@@ -18,10 +18,8 @@ class TestLoadImageFromFile:
assert results['img_path'] == osp.join(data_prefix, 'color.jpg') assert results['img_path'] == osp.join(data_prefix, 'color.jpg')
assert results['img'].shape == (300, 400, 3) assert results['img'].shape == (300, 400, 3)
assert results['img'].dtype == np.uint8 assert results['img'].dtype == np.uint8
assert results['height'] == 300 assert results['img_shape'] == (300, 400)
assert results['width'] == 400 assert results['ori_shape'] == (300, 400)
assert results['ori_height'] == 300
assert results['ori_width'] == 400
assert repr(transform) == transform.__class__.__name__ + \ assert repr(transform) == transform.__class__.__name__ + \
"(to_float32=False, color_type='color', " + \ "(to_float32=False, color_type='color', " + \
"imdecode_backend='cv2', file_client_args={'backend': 'disk'})" "imdecode_backend='cv2', file_client_args={'backend': 'disk'})"
...@@ -86,7 +84,7 @@ class TestLoadAnnotations: ...@@ -86,7 +84,7 @@ class TestLoadAnnotations:
results = transform(copy.deepcopy(self.results)) results = transform(copy.deepcopy(self.results))
assert 'gt_bboxes_labels' in results assert 'gt_bboxes_labels' in results
assert (results['gt_bboxes_labels'] == np.array([1, 2])).all() assert (results['gt_bboxes_labels'] == np.array([1, 2])).all()
assert results['gt_bboxes_labels'].dtype == np.int32 assert results['gt_bboxes_labels'].dtype == np.int64
def test_load_kps(self): def test_load_kps(self):
transform = LoadAnnotations( transform = LoadAnnotations(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment