Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b96d381c
Unverified
Commit
b96d381c
authored
May 24, 2021
by
Nicolas Hug
Committed by
GitHub
May 24, 2021
Browse files
Use torch.testing.assert_close in test_functional_tensor (#3876)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
963d432c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
42 deletions
+42
-42
test/test_functional_tensor.py
test/test_functional_tensor.py
+42
-42
No files found.
test/test_functional_tensor.py
View file @
b96d381c
...
...
@@ -15,6 +15,7 @@ import torchvision.transforms as T
from
torchvision.transforms
import
InterpolationMode
from
common_utils
import
TransformsTester
,
cpu_and_gpu
,
needs_cuda
from
_assert_utils
import
assert_equal
from
typing
import
Dict
,
List
,
Sequence
,
Tuple
...
...
@@ -39,13 +40,13 @@ class Tester(TransformsTester):
for
i
in
range
(
len
(
batch_tensors
)):
img_tensor
=
batch_tensors
[
i
,
...]
transformed_img
=
fn
(
img_tensor
,
**
fn_kwargs
)
self
.
assert
True
(
transformed_img
.
equal
(
transformed_batch
[
i
,
...])
)
assert
_equal
(
transformed_img
,
transformed_batch
[
i
,
...])
if
scripted_fn_atol
>=
0
:
scripted_fn
=
torch
.
jit
.
script
(
fn
)
# scriptable function test
s_transformed_batch
=
scripted_fn
(
batch_tensors
,
**
fn_kwargs
)
self
.
assert
Tru
e
(
transformed_batch
.
allclose
(
s_transformed_batch
,
atol
=
scripted_fn_atol
)
)
torch
.
testing
.
assert
_clos
e
(
transformed_batch
,
s_transformed_batch
,
rtol
=
1e-5
,
atol
=
scripted_fn_atol
)
def
test_assert_image_tensor
(
self
):
shape
=
(
100
,)
...
...
@@ -79,7 +80,7 @@ class Tester(TransformsTester):
# scriptable function test
vflipped_img_script
=
script_vflip
(
img_tensor
)
self
.
assert
True
(
vflipped_img
.
equal
(
vflipped_img_script
)
)
assert
_equal
(
vflipped_img
,
vflipped_img_script
)
batch_tensors
=
self
.
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
self
.
device
)
self
.
_test_fn_on_batch
(
batch_tensors
,
F
.
vflip
)
...
...
@@ -94,7 +95,7 @@ class Tester(TransformsTester):
# scriptable function test
hflipped_img_script
=
script_hflip
(
img_tensor
)
self
.
assert
True
(
hflipped_img
.
equal
(
hflipped_img_script
)
)
assert
_equal
(
hflipped_img
,
hflipped_img_script
)
batch_tensors
=
self
.
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
self
.
device
)
self
.
_test_fn_on_batch
(
batch_tensors
,
F
.
hflip
)
...
...
@@ -140,11 +141,10 @@ class Tester(TransformsTester):
for
h1
,
s1
,
v1
in
zip
(
h
,
s
,
v
):
rgb
.
append
(
colorsys
.
hsv_to_rgb
(
h1
,
s1
,
v1
))
colorsys_img
=
torch
.
tensor
(
rgb
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
max_diff
=
(
ft_img
-
colorsys_img
).
abs
().
max
()
self
.
assertLess
(
max_diff
,
1e-5
)
torch
.
testing
.
assert_close
(
ft_img
,
colorsys_img
,
rtol
=
0.0
,
atol
=
1e-5
)
s_rgb_img
=
scripted_fn
(
hsv_img
)
self
.
assert
Tru
e
(
rgb_img
.
allclose
(
s_rgb_img
)
)
torch
.
testing
.
assert
_clos
e
(
rgb_img
,
s_rgb_img
)
batch_tensors
=
self
.
_create_data_batch
(
120
,
100
,
num_samples
=
4
,
device
=
self
.
device
).
float
()
self
.
_test_fn_on_batch
(
batch_tensors
,
F_t
.
_hsv2rgb
)
...
...
@@ -177,7 +177,7 @@ class Tester(TransformsTester):
self
.
assertLess
(
max_diff
,
1e-5
)
s_hsv_img
=
scripted_fn
(
rgb_img
)
self
.
assert
Tru
e
(
hsv_img
.
allclose
(
s_hsv_img
,
atol
=
1e-7
)
)
torch
.
testing
.
assert
_clos
e
(
hsv_img
,
s_hsv_img
,
rtol
=
1e-5
,
atol
=
1e-7
)
batch_tensors
=
self
.
_create_data_batch
(
120
,
100
,
num_samples
=
4
,
device
=
self
.
device
).
float
()
self
.
_test_fn_on_batch
(
batch_tensors
,
F_t
.
_rgb2hsv
)
...
...
@@ -194,7 +194,7 @@ class Tester(TransformsTester):
self
.
approxEqualTensorToPIL
(
gray_tensor
.
float
(),
gray_pil_image
,
tol
=
1.0
+
1e-10
,
agg_method
=
"max"
)
s_gray_tensor
=
script_rgb_to_grayscale
(
img_tensor
,
num_output_channels
=
num_output_channels
)
self
.
assert
True
(
s_gray_tensor
.
equal
(
gray_tensor
)
)
assert
_equal
(
s_gray_tensor
,
gray_tensor
)
batch_tensors
=
self
.
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
self
.
device
)
self
.
_test_fn_on_batch
(
batch_tensors
,
F
.
rgb_to_grayscale
,
num_output_channels
=
num_output_channels
)
...
...
@@ -240,12 +240,12 @@ class Tester(TransformsTester):
for
j
in
range
(
len
(
tuple_transformed_imgs
)):
true_transformed_img
=
tuple_transformed_imgs
[
j
]
transformed_img
=
tuple_transformed_batches
[
j
][
i
,
...]
self
.
assert
True
(
true_transformed_img
.
equal
(
transformed_img
)
)
assert
_equal
(
true_transformed_img
,
transformed_img
)
# scriptable function test
s_tuple_transformed_batches
=
script_five_crop
(
batch_tensors
,
[
10
,
11
])
for
transformed_batch
,
s_transformed_batch
in
zip
(
tuple_transformed_batches
,
s_tuple_transformed_batches
):
self
.
assert
True
(
transformed_batch
.
equal
(
s_transformed_batch
)
)
assert
_equal
(
transformed_batch
,
s_transformed_batch
)
def
test_ten_crop
(
self
):
script_ten_crop
=
torch
.
jit
.
script
(
F
.
ten_crop
)
...
...
@@ -272,12 +272,12 @@ class Tester(TransformsTester):
for
j
in
range
(
len
(
tuple_transformed_imgs
)):
true_transformed_img
=
tuple_transformed_imgs
[
j
]
transformed_img
=
tuple_transformed_batches
[
j
][
i
,
...]
self
.
assert
True
(
true_transformed_img
.
equal
(
transformed_img
)
)
assert
_equal
(
true_transformed_img
,
transformed_img
)
# scriptable function test
s_tuple_transformed_batches
=
script_ten_crop
(
batch_tensors
,
[
10
,
11
])
for
transformed_batch
,
s_transformed_batch
in
zip
(
tuple_transformed_batches
,
s_tuple_transformed_batches
):
self
.
assert
True
(
transformed_batch
.
equal
(
s_transformed_batch
)
)
assert
_equal
(
transformed_batch
,
s_transformed_batch
)
def
test_pad
(
self
):
script_fn
=
torch
.
jit
.
script
(
F
.
pad
)
...
...
@@ -320,7 +320,7 @@ class Tester(TransformsTester):
else
:
script_pad
=
pad
pad_tensor_script
=
script_fn
(
tensor
,
script_pad
,
**
kwargs
)
self
.
assert
True
(
pad_tensor
.
equal
(
pad_tensor_script
)
,
msg
=
"{}, {}"
.
format
(
pad
,
kwargs
))
assert
_equal
(
pad_tensor
,
pad_tensor_script
,
msg
=
"{}, {}"
.
format
(
pad
,
kwargs
))
self
.
_test_fn_on_batch
(
batch_tensors
,
F
.
pad
,
padding
=
script_pad
,
**
kwargs
)
...
...
@@ -348,9 +348,10 @@ class Tester(TransformsTester):
resized_tensor
=
F
.
resize
(
tensor
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
)
resized_pil_img
=
F
.
resize
(
pil_img
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
)
self
.
assertEqual
(
resized_tensor
.
size
()[
1
:],
resized_pil_img
.
size
[::
-
1
],
msg
=
"{}, {}"
.
format
(
size
,
interpolation
)
assert_equal
(
resized_tensor
.
size
()[
1
:],
resized_pil_img
.
size
[::
-
1
],
msg
=
"{}, {}"
.
format
(
size
,
interpolation
),
)
if
interpolation
not
in
[
NEAREST
,
]:
...
...
@@ -374,7 +375,7 @@ class Tester(TransformsTester):
resize_result
=
script_fn
(
tensor
,
size
=
script_size
,
interpolation
=
interpolation
,
max_size
=
max_size
)
self
.
assert
True
(
resized_tensor
.
equal
(
resize_result
)
,
msg
=
"{}, {}"
.
format
(
size
,
interpolation
))
assert
_equal
(
resized_tensor
,
resize_result
,
msg
=
"{}, {}"
.
format
(
size
,
interpolation
))
self
.
_test_fn_on_batch
(
batch_tensors
,
F
.
resize
,
size
=
script_size
,
interpolation
=
interpolation
,
max_size
=
max_size
...
...
@@ -384,7 +385,7 @@ class Tester(TransformsTester):
with
self
.
assertWarnsRegex
(
UserWarning
,
r
"Argument interpolation should be of type InterpolationMode"
):
res1
=
F
.
resize
(
tensor
,
size
=
32
,
interpolation
=
2
)
res2
=
F
.
resize
(
tensor
,
size
=
32
,
interpolation
=
BILINEAR
)
self
.
assert
True
(
res1
.
equal
(
res2
)
)
assert
_equal
(
res1
,
res2
)
for
img
in
(
tensor
,
pil_img
):
exp_msg
=
"max_size should only be passed if size specifies the length of the smaller edge"
...
...
@@ -400,15 +401,17 @@ class Tester(TransformsTester):
for
mode
in
[
NEAREST
,
BILINEAR
,
BICUBIC
]:
out_tensor
=
F
.
resized_crop
(
tensor
,
top
=
0
,
left
=
0
,
height
=
26
,
width
=
36
,
size
=
[
26
,
36
],
interpolation
=
mode
)
self
.
assert
True
(
tensor
.
equal
(
out_tensor
)
,
msg
=
"{} vs {}"
.
format
(
out_tensor
[
0
,
:
5
,
:
5
],
tensor
[
0
,
:
5
,
:
5
]))
assert
_equal
(
tensor
,
out_tensor
,
msg
=
"{} vs {}"
.
format
(
out_tensor
[
0
,
:
5
,
:
5
],
tensor
[
0
,
:
5
,
:
5
]))
# 2) resize by half and crop a TL corner
tensor
,
_
=
self
.
_create_data
(
26
,
36
,
device
=
self
.
device
)
out_tensor
=
F
.
resized_crop
(
tensor
,
top
=
0
,
left
=
0
,
height
=
20
,
width
=
30
,
size
=
[
10
,
15
],
interpolation
=
NEAREST
)
expected_out_tensor
=
tensor
[:,
:
20
:
2
,
:
30
:
2
]
self
.
assertTrue
(
expected_out_tensor
.
equal
(
out_tensor
),
msg
=
"{} vs {}"
.
format
(
expected_out_tensor
[
0
,
:
10
,
:
10
],
out_tensor
[
0
,
:
10
,
:
10
])
assert_equal
(
expected_out_tensor
,
out_tensor
,
check_stride
=
False
,
msg
=
"{} vs {}"
.
format
(
expected_out_tensor
[
0
,
:
10
,
:
10
],
out_tensor
[
0
,
:
10
,
:
10
]),
)
batch_tensors
=
self
.
_create_data_batch
(
26
,
36
,
num_samples
=
4
,
device
=
self
.
device
)
...
...
@@ -420,15 +423,11 @@ class Tester(TransformsTester):
# 1) identity map
out_tensor
=
F
.
affine
(
tensor
,
angle
=
0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
interpolation
=
NEAREST
)
self
.
assertTrue
(
tensor
.
equal
(
out_tensor
),
msg
=
"{} vs {}"
.
format
(
out_tensor
[
0
,
:
5
,
:
5
],
tensor
[
0
,
:
5
,
:
5
])
)
assert_equal
(
tensor
,
out_tensor
,
msg
=
"{} vs {}"
.
format
(
out_tensor
[
0
,
:
5
,
:
5
],
tensor
[
0
,
:
5
,
:
5
]))
out_tensor
=
scripted_affine
(
tensor
,
angle
=
0
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
interpolation
=
NEAREST
)
self
.
assertTrue
(
tensor
.
equal
(
out_tensor
),
msg
=
"{} vs {}"
.
format
(
out_tensor
[
0
,
:
5
,
:
5
],
tensor
[
0
,
:
5
,
:
5
])
)
assert_equal
(
tensor
,
out_tensor
,
msg
=
"{} vs {}"
.
format
(
out_tensor
[
0
,
:
5
,
:
5
],
tensor
[
0
,
:
5
,
:
5
]))
def
_test_affine_square_rotations
(
self
,
tensor
,
pil_img
,
scripted_affine
):
# 2) Test rotation
...
...
@@ -452,9 +451,11 @@ class Tester(TransformsTester):
tensor
,
angle
=
a
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
interpolation
=
NEAREST
)
if
true_tensor
is
not
None
:
self
.
assertTrue
(
true_tensor
.
equal
(
out_tensor
),
msg
=
"{}
\n
{} vs
\n
{}"
.
format
(
a
,
out_tensor
[
0
,
:
5
,
:
5
],
true_tensor
[
0
,
:
5
,
:
5
])
assert_equal
(
true_tensor
,
out_tensor
,
msg
=
"{}
\n
{} vs
\n
{}"
.
format
(
a
,
out_tensor
[
0
,
:
5
,
:
5
],
true_tensor
[
0
,
:
5
,
:
5
]),
check_stride
=
False
,
)
if
out_tensor
.
dtype
!=
torch
.
uint8
:
...
...
@@ -593,18 +594,19 @@ class Tester(TransformsTester):
with
self
.
assertWarnsRegex
(
UserWarning
,
r
"Argument resample is deprecated and will be removed"
):
res1
=
F
.
affine
(
tensor
,
45
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
resample
=
2
)
res2
=
F
.
affine
(
tensor
,
45
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
interpolation
=
BILINEAR
)
self
.
assert
True
(
res1
.
equal
(
res2
)
)
assert
_equal
(
res1
,
res2
)
# assert changed type warning
with
self
.
assertWarnsRegex
(
UserWarning
,
r
"Argument interpolation should be of type InterpolationMode"
):
res1
=
F
.
affine
(
tensor
,
45
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
interpolation
=
2
)
res2
=
F
.
affine
(
tensor
,
45
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
interpolation
=
BILINEAR
)
self
.
assert
True
(
res1
.
equal
(
res2
)
)
assert
_equal
(
res1
,
res2
)
with
self
.
assertWarnsRegex
(
UserWarning
,
r
"Argument fillcolor is deprecated and will be removed"
):
res1
=
F
.
affine
(
pil_img
,
45
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
fillcolor
=
10
)
res2
=
F
.
affine
(
pil_img
,
45
,
translate
=
[
0
,
0
],
scale
=
1.0
,
shear
=
[
0.0
,
0.0
],
fill
=
10
)
self
.
assertEqual
(
res1
,
res2
)
# we convert the PIL images to numpy as assert_equal doesn't work on PIL images.
assert_equal
(
np
.
asarray
(
res1
),
np
.
asarray
(
res2
))
def
_test_rotate_all_options
(
self
,
tensor
,
pil_img
,
scripted_rotate
,
centers
):
img_size
=
pil_img
.
size
...
...
@@ -682,13 +684,13 @@ class Tester(TransformsTester):
with
self
.
assertWarnsRegex
(
UserWarning
,
r
"Argument resample is deprecated and will be removed"
):
res1
=
F
.
rotate
(
tensor
,
45
,
resample
=
2
)
res2
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
BILINEAR
)
self
.
assert
True
(
res1
.
equal
(
res2
)
)
assert
_equal
(
res1
,
res2
)
# assert changed type warning
with
self
.
assertWarnsRegex
(
UserWarning
,
r
"Argument interpolation should be of type InterpolationMode"
):
res1
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
2
)
res2
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
BILINEAR
)
self
.
assert
True
(
res1
.
equal
(
res2
)
)
assert
_equal
(
res1
,
res2
)
def
test_gaussian_blur
(
self
):
small_image_tensor
=
torch
.
from_numpy
(
...
...
@@ -747,10 +749,8 @@ class Tester(TransformsTester):
for
fn
in
[
F
.
gaussian_blur
,
scripted_transform
]:
out
=
fn
(
tensor
,
kernel_size
=
ksize
,
sigma
=
sigma
)
self
.
assertEqual
(
true_out
.
shape
,
out
.
shape
,
msg
=
"{}, {}"
.
format
(
ksize
,
sigma
))
self
.
assertLessEqual
(
torch
.
max
(
true_out
.
float
()
-
out
.
float
()),
1.0
,
torch
.
testing
.
assert_close
(
out
,
true_out
,
rtol
=
0.0
,
atol
=
1.0
,
check_stride
=
False
,
msg
=
"{}, {}"
.
format
(
ksize
,
sigma
)
)
...
...
@@ -771,7 +771,7 @@ class CUDATester(Tester):
img_chan
=
torch
.
randint
(
0
,
256
,
size
=
size
).
to
(
'cpu'
)
scaled_cpu
=
F_t
.
_scale_channel
(
img_chan
)
scaled_cuda
=
F_t
.
_scale_channel
(
img_chan
.
to
(
'cuda'
))
self
.
assert
True
(
scaled_cpu
.
equal
(
scaled_cuda
.
to
(
'cpu'
))
)
assert
_equal
(
scaled_cpu
,
scaled_cuda
.
to
(
'cpu'
))
def
_get_data_dims_and_points_for_perspective
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment