Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
182f80df
Unverified
Commit
182f80df
authored
Jun 07, 2021
by
Nicolas Hug
Committed by
GitHub
Jun 07, 2021
Browse files
Finish porting test_functional_tensor.py to pytest (#3990)
parent
a629a9b2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
109 deletions
+13
-109
test/test_functional_tensor.py
test/test_functional_tensor.py
+13
-109
No files found.
test/test_functional_tensor.py
View file @
182f80df
import
itertools
import
itertools
import
os
import
os
import
unittest
import
colorsys
import
colorsys
import
math
import
math
...
@@ -31,103 +30,8 @@ from typing import Dict, List, Sequence, Tuple
...
@@ -31,103 +30,8 @@ from typing import Dict, List, Sequence, Tuple
NEAREST
,
BILINEAR
,
BICUBIC
=
InterpolationMode
.
NEAREST
,
InterpolationMode
.
BILINEAR
,
InterpolationMode
.
BICUBIC
NEAREST
,
BILINEAR
,
BICUBIC
=
InterpolationMode
.
NEAREST
,
InterpolationMode
.
BILINEAR
,
InterpolationMode
.
BICUBIC
class
Tester
(
unittest
.
TestCase
):
@
needs_cuda
def
test_scale_channel
():
def
setUp
(
self
):
self
.
device
=
"cpu"
def
_test_rotate_all_options
(
self
,
tensor
,
pil_img
,
scripted_rotate
,
centers
):
img_size
=
pil_img
.
size
dt
=
tensor
.
dtype
for
r
in
[
NEAREST
,
]:
for
a
in
range
(
-
180
,
180
,
17
):
for
e
in
[
True
,
False
]:
for
c
in
centers
:
for
f
in
[
None
,
[
0
,
0
,
0
],
(
1
,
2
,
3
),
[
255
,
255
,
255
],
[
1
,
],
(
2.0
,
)]:
f_pil
=
int
(
f
[
0
])
if
f
is
not
None
and
len
(
f
)
==
1
else
f
out_pil_img
=
F
.
rotate
(
pil_img
,
angle
=
a
,
interpolation
=
r
,
expand
=
e
,
center
=
c
,
fill
=
f_pil
)
out_pil_tensor
=
torch
.
from_numpy
(
np
.
array
(
out_pil_img
).
transpose
((
2
,
0
,
1
)))
for
fn
in
[
F
.
rotate
,
scripted_rotate
]:
out_tensor
=
fn
(
tensor
,
angle
=
a
,
interpolation
=
r
,
expand
=
e
,
center
=
c
,
fill
=
f
).
cpu
()
if
out_tensor
.
dtype
!=
torch
.
uint8
:
out_tensor
=
out_tensor
.
to
(
torch
.
uint8
)
self
.
assertEqual
(
out_tensor
.
shape
,
out_pil_tensor
.
shape
,
msg
=
"{}: {} vs {}"
.
format
(
(
img_size
,
r
,
dt
,
a
,
e
,
c
),
out_tensor
.
shape
,
out_pil_tensor
.
shape
))
num_diff_pixels
=
(
out_tensor
!=
out_pil_tensor
).
sum
().
item
()
/
3.0
ratio_diff_pixels
=
num_diff_pixels
/
out_tensor
.
shape
[
-
1
]
/
out_tensor
.
shape
[
-
2
]
# Tolerance : less than 3% of different pixels
self
.
assertLess
(
ratio_diff_pixels
,
0.03
,
msg
=
"{}: {}
\n
{} vs
\n
{}"
.
format
(
(
img_size
,
r
,
dt
,
a
,
e
,
c
,
f
),
ratio_diff_pixels
,
out_tensor
[
0
,
:
7
,
:
7
],
out_pil_tensor
[
0
,
:
7
,
:
7
]
)
)
def
test_rotate
(
self
):
# Tests on square image
scripted_rotate
=
torch
.
jit
.
script
(
F
.
rotate
)
data
=
[
_create_data
(
26
,
26
,
device
=
self
.
device
),
_create_data
(
32
,
26
,
device
=
self
.
device
)]
for
tensor
,
pil_img
in
data
:
img_size
=
pil_img
.
size
centers
=
[
None
,
(
int
(
img_size
[
0
]
*
0.3
),
int
(
img_size
[
0
]
*
0.4
)),
[
int
(
img_size
[
0
]
*
0.5
),
int
(
img_size
[
0
]
*
0.6
)]
]
for
dt
in
[
None
,
torch
.
float32
,
torch
.
float64
,
torch
.
float16
]:
if
dt
==
torch
.
float16
and
torch
.
device
(
self
.
device
).
type
==
"cpu"
:
# skip float16 on CPU case
continue
if
dt
is
not
None
:
tensor
=
tensor
.
to
(
dtype
=
dt
)
self
.
_test_rotate_all_options
(
tensor
,
pil_img
,
scripted_rotate
,
centers
)
batch_tensors
=
_create_data_batch
(
26
,
36
,
num_samples
=
4
,
device
=
self
.
device
)
if
dt
is
not
None
:
batch_tensors
=
batch_tensors
.
to
(
dtype
=
dt
)
center
=
(
20
,
22
)
_test_fn_on_batch
(
batch_tensors
,
F
.
rotate
,
angle
=
32
,
interpolation
=
NEAREST
,
expand
=
True
,
center
=
center
)
tensor
,
pil_img
=
data
[
0
]
# assert deprecation warning and non-BC
with
self
.
assertWarnsRegex
(
UserWarning
,
r
"Argument resample is deprecated and will be removed"
):
res1
=
F
.
rotate
(
tensor
,
45
,
resample
=
2
)
res2
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
BILINEAR
)
assert_equal
(
res1
,
res2
)
# assert changed type warning
with
self
.
assertWarnsRegex
(
UserWarning
,
r
"Argument interpolation should be of type InterpolationMode"
):
res1
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
2
)
res2
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
BILINEAR
)
assert_equal
(
res1
,
res2
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Skip if no CUDA device"
)
class
CUDATester
(
Tester
):
def
setUp
(
self
):
self
.
device
=
"cuda"
def
test_scale_channel
(
self
):
"""Make sure that _scale_channel gives the same results on CPU and GPU as
"""Make sure that _scale_channel gives the same results on CPU and GPU as
histc or bincount are used depending on the device.
histc or bincount are used depending on the device.
"""
"""
...
@@ -1271,4 +1175,4 @@ def test_ten_crop(device):
...
@@ -1271,4 +1175,4 @@ def test_ten_crop(device):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unit
test
.
main
()
py
test
.
main
(
[
__file__
]
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment