Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7ba3d7e2
Unverified
Commit
7ba3d7e2
authored
Sep 25, 2023
by
Philip Meier
Committed by
GitHub
Sep 25, 2023
Browse files
port tests for transforms.RandomZoomOut (#7975)
parent
c3aee873
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
30 deletions
+65
-30
test/test_transforms_v2.py
test/test_transforms_v2.py
+0
-29
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+64
-0
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+1
-1
No files found.
test/test_transforms_v2.py
View file @
7ba3d7e2
...
@@ -136,7 +136,6 @@ class TestSmoke:
...
@@ -136,7 +136,6 @@ class TestSmoke:
(
transforms
.
RandomRotation
(
degrees
=
30
),
None
),
(
transforms
.
RandomRotation
(
degrees
=
30
),
None
),
(
transforms
.
RandomShortestSize
(
min_size
=
10
,
antialias
=
True
),
None
),
(
transforms
.
RandomShortestSize
(
min_size
=
10
,
antialias
=
True
),
None
),
(
transforms
.
RandomVerticalFlip
(
p
=
1.0
),
None
),
(
transforms
.
RandomVerticalFlip
(
p
=
1.0
),
None
),
(
transforms
.
RandomZoomOut
(
p
=
1.0
),
None
),
(
transforms
.
Resize
([
16
,
16
],
antialias
=
True
),
None
),
(
transforms
.
Resize
([
16
,
16
],
antialias
=
True
),
None
),
(
transforms
.
ScaleJitter
((
16
,
16
),
scale_range
=
(
0.8
,
1.2
),
antialias
=
True
),
None
),
(
transforms
.
ScaleJitter
((
16
,
16
),
scale_range
=
(
0.8
,
1.2
),
antialias
=
True
),
None
),
(
transforms
.
ClampBoundingBoxes
(),
None
),
(
transforms
.
ClampBoundingBoxes
(),
None
),
...
@@ -388,34 +387,6 @@ def test_pure_tensor_heuristic(flat_inputs):
...
@@ -388,34 +387,6 @@ def test_pure_tensor_heuristic(flat_inputs):
assert
transform
.
was_applied
(
output
,
input
)
assert
transform
.
was_applied
(
output
,
input
)
class
TestRandomZoomOut
:
def
test_assertions
(
self
):
with
pytest
.
raises
(
TypeError
,
match
=
"Got inappropriate fill arg"
):
transforms
.
RandomZoomOut
(
fill
=
"abc"
)
with
pytest
.
raises
(
TypeError
,
match
=
"should be a sequence of length"
):
transforms
.
RandomZoomOut
(
0
,
side_range
=
0
)
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid canvas side range"
):
transforms
.
RandomZoomOut
(
0
,
side_range
=
[
4.0
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"fill"
,
[
0
,
[
1
,
2
,
3
],
(
2
,
3
,
4
)])
@
pytest
.
mark
.
parametrize
(
"side_range"
,
[(
1.0
,
4.0
),
[
2.0
,
5.0
]])
def
test__get_params
(
self
,
fill
,
side_range
):
transform
=
transforms
.
RandomZoomOut
(
fill
=
fill
,
side_range
=
side_range
)
h
,
w
=
size
=
(
24
,
32
)
image
=
make_image
(
size
)
params
=
transform
.
_get_params
([
image
])
assert
len
(
params
[
"padding"
])
==
4
assert
0
<=
params
[
"padding"
][
0
]
<=
(
side_range
[
1
]
-
1
)
*
w
assert
0
<=
params
[
"padding"
][
1
]
<=
(
side_range
[
1
]
-
1
)
*
h
assert
0
<=
params
[
"padding"
][
2
]
<=
(
side_range
[
1
]
-
1
)
*
w
assert
0
<=
params
[
"padding"
][
3
]
<=
(
side_range
[
1
]
-
1
)
*
h
class
TestElasticTransform
:
class
TestElasticTransform
:
def
test_assertions
(
self
):
def
test_assertions
(
self
):
...
...
test/test_transforms_v2_refactored.py
View file @
7ba3d7e2
...
@@ -4000,3 +4000,67 @@ class TestRgbToGrayscale:
...
@@ -4000,3 +4000,67 @@ class TestRgbToGrayscale:
expected
=
F
.
to_image
(
F
.
rgb_to_grayscale
(
F
.
to_pil_image
(
image
),
num_output_channels
=
num_input_channels
))
expected
=
F
.
to_image
(
F
.
rgb_to_grayscale
(
F
.
to_pil_image
(
image
),
num_output_channels
=
num_input_channels
))
assert_equal
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
assert_equal
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
class
TestRandomZoomOut
:
# Tests are light because this largely relies on the already tested `pad` kernels.
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_boxes
,
make_segmentation_mask
,
make_detection_mask
,
make_video
,
],
)
def
test_transform
(
self
,
make_input
):
check_transform
(
transforms
.
RandomZoomOut
(
p
=
1
),
make_input
())
def
test_transform_error
(
self
):
for
side_range
in
[
None
,
1
,
[
1
,
2
,
3
]]:
with
pytest
.
raises
(
ValueError
if
isinstance
(
side_range
,
list
)
else
TypeError
,
match
=
"should be a sequence of length 2"
):
transforms
.
RandomZoomOut
(
side_range
=
side_range
)
for
side_range
in
[[
0.5
,
1.5
],
[
2.0
,
1.0
]]:
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid side range"
):
transforms
.
RandomZoomOut
(
side_range
=
side_range
)
@
pytest
.
mark
.
parametrize
(
"side_range"
,
[(
1.0
,
4.0
),
[
2.0
,
5.0
]])
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_boxes
,
make_segmentation_mask
,
make_detection_mask
,
make_video
,
],
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_transform_params_correctness
(
self
,
side_range
,
make_input
,
device
):
if
make_input
is
make_image_pil
and
device
!=
"cpu"
:
pytest
.
skip
(
"PIL image tests with parametrization device!='cpu' will degenerate to that anyway."
)
transform
=
transforms
.
RandomZoomOut
(
side_range
=
side_range
)
input
=
make_input
()
height
,
width
=
F
.
get_size
(
input
)
params
=
transform
.
_get_params
([
input
])
assert
"padding"
in
params
padding
=
params
[
"padding"
]
assert
len
(
padding
)
==
4
assert
0
<=
padding
[
0
]
<=
(
side_range
[
1
]
-
1
)
*
width
assert
0
<=
padding
[
1
]
<=
(
side_range
[
1
]
-
1
)
*
height
assert
0
<=
padding
[
2
]
<=
(
side_range
[
1
]
-
1
)
*
width
assert
0
<=
padding
[
3
]
<=
(
side_range
[
1
]
-
1
)
*
height
torchvision/transforms/v2/_geometry.py
View file @
7ba3d7e2
...
@@ -546,7 +546,7 @@ class RandomZoomOut(_RandomApplyTransform):
...
@@ -546,7 +546,7 @@ class RandomZoomOut(_RandomApplyTransform):
self
.
side_range
=
side_range
self
.
side_range
=
side_range
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
raise
ValueError
(
f
"Invalid
canvas
side range provided
{
side_range
}
."
)
raise
ValueError
(
f
"Invalid side range provided
{
side_range
}
."
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
orig_h
,
orig_w
=
query_size
(
flat_inputs
)
orig_h
,
orig_w
=
query_size
(
flat_inputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment