"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "5e75c8e803e623481e2e76ba93444301d498be54"
Unverified Commit c66da5e8 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Added `crop_segmentation_mask` op (#5851)

* Added `crop_segmentation_mask` op

* Fixed failed mypy
parent ca265374
...@@ -332,6 +332,20 @@ def crop_bounding_box(): ...@@ -332,6 +332,20 @@ def crop_bounding_box():
) )
@register_kernel_info_from_sample_inputs_fn
def crop_segmentation_mask():
for mask, top, left, height, width in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
):
yield SampleInput(
mask,
top=top,
left=left,
height=height,
width=width,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kernel", "kernel",
[ [
...@@ -860,3 +874,44 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte ...@@ -860,3 +874,44 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
) )
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"top, left, height, width",
[
[4, 6, 30, 40],
[-8, 6, 70, 40],
[-8, -6, 70, 8],
],
)
def test_correctness_crop_segmentation_mask(device, top, left, height, width):
def _compute_expected_mask(mask, top_, left_, height_, width_):
h, w = mask.shape[-2], mask.shape[-1]
if top_ >= 0 and left_ >= 0 and top_ + height_ < h and left_ + width_ < w:
expected = mask[..., top_ : top_ + height_, left_ : left_ + width_]
else:
# Create output mask
expected_shape = mask.shape[:-2] + (height_, width_)
expected = torch.zeros(expected_shape, device=mask.device, dtype=mask.dtype)
out_y1 = abs(top_) if top_ < 0 else 0
out_y2 = h - top_ if top_ + height_ >= h else height_
out_x1 = abs(left_) if left_ < 0 else 0
out_x2 = w - left_ if left_ + width_ >= w else width_
in_y1 = 0 if top_ < 0 else top_
in_y2 = h if top_ + height_ >= h else top_ + height_
in_x1 = 0 if left_ < 0 else left_
in_x2 = w if left_ + width_ >= w else left_ + width_
# Paste input mask into output
expected[..., out_y1:out_y2, out_x1:out_x2] = mask[..., in_y1:in_y2, in_x1:in_x2]
return expected
for mask in make_segmentation_masks():
if mask.device != torch.device(device):
mask = mask.to(device)
output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
expected_mask = _compute_expected_mask(mask, top, left, height, width)
torch.testing.assert_close(output_mask, expected_mask)
...@@ -63,6 +63,7 @@ from ._geometry import ( ...@@ -63,6 +63,7 @@ from ._geometry import (
crop_bounding_box, crop_bounding_box,
crop_image_tensor, crop_image_tensor,
crop_image_pil, crop_image_pil,
crop_segmentation_mask,
perspective_image_tensor, perspective_image_tensor,
perspective_image_pil, perspective_image_pil,
vertical_flip_image_tensor, vertical_flip_image_tensor,
......
...@@ -440,6 +440,10 @@ def crop_bounding_box( ...@@ -440,6 +440,10 @@ def crop_bounding_box(
).view(shape) ).view(shape)
def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(img, top, left, height, width)
def perspective_image_tensor( def perspective_image_tensor(
img: torch.Tensor, img: torch.Tensor,
perspective_coeffs: List[float], perspective_coeffs: List[float],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment