Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c66da5e8
Unverified
Commit
c66da5e8
authored
Apr 22, 2022
by
vfdev
Committed by
GitHub
Apr 22, 2022
Browse files
Added `crop_segmentation_mask` op (#5851)
* Added `crop_segmentation_mask` op * Fixed failed mypy
parent
ca265374
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
0 deletions
+60
-0
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+55
-0
torchvision/prototype/transforms/functional/__init__.py
torchvision/prototype/transforms/functional/__init__.py
+1
-0
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
+4
-0
No files found.
test/test_prototype_transforms_functional.py
View file @
c66da5e8
...
@@ -332,6 +332,20 @@ def crop_bounding_box():
...
@@ -332,6 +332,20 @@ def crop_bounding_box():
)
)
@
register_kernel_info_from_sample_inputs_fn
def
crop_segmentation_mask
():
for
mask
,
top
,
left
,
height
,
width
in
itertools
.
product
(
make_segmentation_masks
(),
[
-
8
,
0
,
9
],
[
-
8
,
0
,
9
],
[
12
,
20
],
[
12
,
20
]
):
yield
SampleInput
(
mask
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"kernel"
,
"kernel"
,
[
[
...
@@ -860,3 +874,44 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
...
@@ -860,3 +874,44 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
)
)
torch
.
testing
.
assert_close
(
output_boxes
.
tolist
(),
expected_bboxes
)
torch
.
testing
.
assert_close
(
output_boxes
.
tolist
(),
expected_bboxes
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"top, left, height, width"
,
[
[
4
,
6
,
30
,
40
],
[
-
8
,
6
,
70
,
40
],
[
-
8
,
-
6
,
70
,
8
],
],
)
def
test_correctness_crop_segmentation_mask
(
device
,
top
,
left
,
height
,
width
):
def
_compute_expected_mask
(
mask
,
top_
,
left_
,
height_
,
width_
):
h
,
w
=
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
]
if
top_
>=
0
and
left_
>=
0
and
top_
+
height_
<
h
and
left_
+
width_
<
w
:
expected
=
mask
[...,
top_
:
top_
+
height_
,
left_
:
left_
+
width_
]
else
:
# Create output mask
expected_shape
=
mask
.
shape
[:
-
2
]
+
(
height_
,
width_
)
expected
=
torch
.
zeros
(
expected_shape
,
device
=
mask
.
device
,
dtype
=
mask
.
dtype
)
out_y1
=
abs
(
top_
)
if
top_
<
0
else
0
out_y2
=
h
-
top_
if
top_
+
height_
>=
h
else
height_
out_x1
=
abs
(
left_
)
if
left_
<
0
else
0
out_x2
=
w
-
left_
if
left_
+
width_
>=
w
else
width_
in_y1
=
0
if
top_
<
0
else
top_
in_y2
=
h
if
top_
+
height_
>=
h
else
top_
+
height_
in_x1
=
0
if
left_
<
0
else
left_
in_x2
=
w
if
left_
+
width_
>=
w
else
left_
+
width_
# Paste input mask into output
expected
[...,
out_y1
:
out_y2
,
out_x1
:
out_x2
]
=
mask
[...,
in_y1
:
in_y2
,
in_x1
:
in_x2
]
return
expected
for
mask
in
make_segmentation_masks
():
if
mask
.
device
!=
torch
.
device
(
device
):
mask
=
mask
.
to
(
device
)
output_mask
=
F
.
crop_segmentation_mask
(
mask
,
top
,
left
,
height
,
width
)
expected_mask
=
_compute_expected_mask
(
mask
,
top
,
left
,
height
,
width
)
torch
.
testing
.
assert_close
(
output_mask
,
expected_mask
)
torchvision/prototype/transforms/functional/__init__.py
View file @
c66da5e8
...
@@ -63,6 +63,7 @@ from ._geometry import (
...
@@ -63,6 +63,7 @@ from ._geometry import (
crop_bounding_box
,
crop_bounding_box
,
crop_image_tensor
,
crop_image_tensor
,
crop_image_pil
,
crop_image_pil
,
crop_segmentation_mask
,
perspective_image_tensor
,
perspective_image_tensor
,
perspective_image_pil
,
perspective_image_pil
,
vertical_flip_image_tensor
,
vertical_flip_image_tensor
,
...
...
torchvision/prototype/transforms/functional/_geometry.py
View file @
c66da5e8
...
@@ -440,6 +440,10 @@ def crop_bounding_box(
...
@@ -440,6 +440,10 @@ def crop_bounding_box(
).
view
(
shape
)
).
view
(
shape
)
def
crop_segmentation_mask
(
img
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
return
crop_image_tensor
(
img
,
top
,
left
,
height
,
width
)
def
perspective_image_tensor
(
def
perspective_image_tensor
(
img
:
torch
.
Tensor
,
img
:
torch
.
Tensor
,
perspective_coeffs
:
List
[
float
],
perspective_coeffs
:
List
[
float
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment