"vscode:/vscode.git/clone" did not exist on "1d34a19710c20bb27e1311326153c804903eb10f"
Commit 5af6c12b authored by liukuikun's avatar liukuikun Committed by zhouzaida
Browse files

[Fix] Pad boader fix (#1757)

* [Fix] expand pad_val'dim to be same with image

* fix comment
parent 53070ebc
...@@ -287,7 +287,7 @@ class Pad(BaseTransform): ...@@ -287,7 +287,7 @@ class Pad(BaseTransform):
Args: Args:
size (tuple, optional): Fixed padding size. size (tuple, optional): Fixed padding size.
Expected padding shape (h, w)Defaults to None. Expected padding shape (w, h). Defaults to None.
size_divisor (int, optional): The divisor of padded size. Defaults to size_divisor (int, optional): The divisor of padded size. Defaults to
None. None.
pad_to_square (bool): Whether to pad the image into a square. pad_to_square (bool): Whether to pad the image into a square.
...@@ -354,6 +354,8 @@ class Pad(BaseTransform): ...@@ -354,6 +354,8 @@ class Pad(BaseTransform):
size = (pad_h, pad_w) size = (pad_h, pad_w)
elif self.size is not None: elif self.size is not None:
size = self.size[::-1] size = self.size[::-1]
if isinstance(pad_val, int) and results['img'].ndim == 3:
pad_val = tuple([pad_val for _ in range(results['img'].shape[2])])
padded_img = mmcv.impad( padded_img = mmcv.impad(
results['img'], results['img'],
shape=size, shape=size,
...@@ -372,7 +374,9 @@ class Pad(BaseTransform): ...@@ -372,7 +374,9 @@ class Pad(BaseTransform):
``results['pad_shape']``.""" ``results['pad_shape']``."""
if results.get('gt_semantic_seg', None) is not None: if results.get('gt_semantic_seg', None) is not None:
pad_val = self.pad_val.get('seg', 255) pad_val = self.pad_val.get('seg', 255)
if isinstance(pad_val, int) and results['gt_semantic_seg'].ndim == 3:
pad_val = tuple(
[pad_val for _ in range(results['gt_semantic_seg'].shape[2])])
results['gt_semantic_seg'] = mmcv.impad( results['gt_semantic_seg'] = mmcv.impad(
results['gt_semantic_seg'], results['gt_semantic_seg'],
shape=results['pad_shape'][:2], shape=results['pad_shape'][:2],
......
...@@ -171,12 +171,50 @@ class TestPad: ...@@ -171,12 +171,50 @@ class TestPad:
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert (results['img'] == np.ones((1333, 1333, 3))).all() assert (results['img'] == np.ones((1333, 1333, 3))).all()
# test pad_val # test pad_val is dict
new_img = np.zeros((1333, 800, 3)) # test rgb image, size=(2000, 2000)
trans = Pad(
size=(2000, 2000),
pad_val=dict(img=(12, 12, 12), seg=(10, 10, 10)))
results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000, 800:2000, :] == 10).all()
trans = Pad(size=(2000, 2000), pad_val=dict(img=(12, 12, 12)))
results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000,
800:2000, :] == 255).all()
# test rgb image, pad_to_square=True
trans = Pad(
pad_to_square=True,
pad_val=dict(img=(12, 12, 12), seg=(10, 10, 10)))
results = trans(copy.deepcopy(data_info))
assert (results['img'][:, 800:1333, :] == 12).all()
assert (results['gt_semantic_seg'][:, 800:1333, :] == 10).all()
trans = Pad(pad_to_square=True, pad_val=dict(img=(12, 12, 12)))
results = trans(copy.deepcopy(data_info))
assert (results['img'][:, 800:1333, :] == 12).all()
assert (results['gt_semantic_seg'][:, 800:1333, :] == 255).all()
# test pad_val is int
# test rgb image
trans = Pad(size=(2000, 2000), pad_val=12)
results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000,
800:2000, :] == 255).all()
# test gray image
new_img = np.random.random((1333, 800))
data_info['img'] = new_img data_info['img'] = new_img
trans = Pad(pad_to_square=True, pad_val=0) new_semantic_seg = np.random.random((1333, 800))
data_info['gt_semantic_seg'] = new_semantic_seg
trans = Pad(size=(2000, 2000), pad_val=12)
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert (results['img'] == np.zeros((1333, 1333, 3))).all() assert (results['img'][1333:2000, 800:2000] == 12).all()
assert (results['gt_semantic_seg'][1333:2000, 800:2000] == 255).all()
def test_repr(self): def test_repr(self):
trans = Pad(pad_to_square=True, size_divisor=11, padding_mode='edge') trans = Pad(pad_to_square=True, size_divisor=11, padding_mode='edge')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment