Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6cb73eb3
Unverified
Commit
6cb73eb3
authored
Jun 06, 2021
by
Sahil Goyal
Committed by
GitHub
Jun 05, 2021
Browse files
port some tests in test_functional_tensor to pytest (#3977)
parent
02e5bb40
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
141 additions
and
138 deletions
+141
-138
test/test_functional_tensor.py
test/test_functional_tensor.py
+141
-138
No files found.
test/test_functional_tensor.py
View file @
6cb73eb3
...
@@ -36,82 +36,6 @@ class Tester(unittest.TestCase):
...
@@ -36,82 +36,6 @@ class Tester(unittest.TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
device
=
"cpu"
self
.
device
=
"cpu"
def
test_assert_image_tensor
(
self
):
shape
=
(
100
,)
tensor
=
torch
.
rand
(
*
shape
,
dtype
=
torch
.
float
,
device
=
self
.
device
)
list_of_methods
=
[(
F_t
.
_get_image_size
,
(
tensor
,
)),
(
F_t
.
vflip
,
(
tensor
,
)),
(
F_t
.
hflip
,
(
tensor
,
)),
(
F_t
.
crop
,
(
tensor
,
1
,
2
,
4
,
5
)),
(
F_t
.
adjust_brightness
,
(
tensor
,
0.
)),
(
F_t
.
adjust_contrast
,
(
tensor
,
1.
)),
(
F_t
.
adjust_hue
,
(
tensor
,
-
0.5
)),
(
F_t
.
adjust_saturation
,
(
tensor
,
2.
)),
(
F_t
.
center_crop
,
(
tensor
,
[
10
,
11
])),
(
F_t
.
five_crop
,
(
tensor
,
[
10
,
11
])),
(
F_t
.
ten_crop
,
(
tensor
,
[
10
,
11
])),
(
F_t
.
pad
,
(
tensor
,
[
2
,
],
2
,
"constant"
)),
(
F_t
.
resize
,
(
tensor
,
[
10
,
11
])),
(
F_t
.
perspective
,
(
tensor
,
[
0.2
,
])),
(
F_t
.
gaussian_blur
,
(
tensor
,
(
2
,
2
),
(
0.7
,
0.5
))),
(
F_t
.
invert
,
(
tensor
,
)),
(
F_t
.
posterize
,
(
tensor
,
0
)),
(
F_t
.
solarize
,
(
tensor
,
0.3
)),
(
F_t
.
adjust_sharpness
,
(
tensor
,
0.3
)),
(
F_t
.
autocontrast
,
(
tensor
,
)),
(
F_t
.
equalize
,
(
tensor
,
))]
for
func
,
args
in
list_of_methods
:
with
self
.
assertRaises
(
Exception
)
as
context
:
func
(
*
args
)
self
.
assertTrue
(
'Tensor is not a torch image.'
in
str
(
context
.
exception
))
def
test_vflip
(
self
):
script_vflip
=
torch
.
jit
.
script
(
F
.
vflip
)
img_tensor
,
pil_img
=
_create_data
(
16
,
18
,
device
=
self
.
device
)
vflipped_img
=
F
.
vflip
(
img_tensor
)
vflipped_pil_img
=
F
.
vflip
(
pil_img
)
_assert_equal_tensor_to_pil
(
vflipped_img
,
vflipped_pil_img
)
# scriptable function test
vflipped_img_script
=
script_vflip
(
img_tensor
)
assert_equal
(
vflipped_img
,
vflipped_img_script
)
batch_tensors
=
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
self
.
device
)
_test_fn_on_batch
(
batch_tensors
,
F
.
vflip
)
def
test_hflip
(
self
):
script_hflip
=
torch
.
jit
.
script
(
F
.
hflip
)
img_tensor
,
pil_img
=
_create_data
(
16
,
18
,
device
=
self
.
device
)
hflipped_img
=
F
.
hflip
(
img_tensor
)
hflipped_pil_img
=
F
.
hflip
(
pil_img
)
_assert_equal_tensor_to_pil
(
hflipped_img
,
hflipped_pil_img
)
# scriptable function test
hflipped_img_script
=
script_hflip
(
img_tensor
)
assert_equal
(
hflipped_img
,
hflipped_img_script
)
batch_tensors
=
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
self
.
device
)
_test_fn_on_batch
(
batch_tensors
,
F
.
hflip
)
def
test_crop
(
self
):
script_crop
=
torch
.
jit
.
script
(
F
.
crop
)
img_tensor
,
pil_img
=
_create_data
(
16
,
18
,
device
=
self
.
device
)
test_configs
=
[
(
1
,
2
,
4
,
5
),
# crop inside top-left corner
(
2
,
12
,
3
,
4
),
# crop inside top-right corner
(
8
,
3
,
5
,
6
),
# crop inside bottom-left corner
(
8
,
11
,
4
,
3
),
# crop inside bottom-right corner
]
for
top
,
left
,
height
,
width
in
test_configs
:
pil_img_cropped
=
F
.
crop
(
pil_img
,
top
,
left
,
height
,
width
)
img_tensor_cropped
=
F
.
crop
(
img_tensor
,
top
,
left
,
height
,
width
)
_assert_equal_tensor_to_pil
(
img_tensor_cropped
,
pil_img_cropped
)
img_tensor_cropped
=
script_crop
(
img_tensor
,
top
,
left
,
height
,
width
)
_assert_equal_tensor_to_pil
(
img_tensor_cropped
,
pil_img_cropped
)
batch_tensors
=
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
self
.
device
)
_test_fn_on_batch
(
batch_tensors
,
F
.
crop
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
)
def
test_hsv2rgb
(
self
):
def
test_hsv2rgb
(
self
):
scripted_fn
=
torch
.
jit
.
script
(
F_t
.
_hsv2rgb
)
scripted_fn
=
torch
.
jit
.
script
(
F_t
.
_hsv2rgb
)
shape
=
(
3
,
100
,
150
)
shape
=
(
3
,
100
,
150
)
...
@@ -610,68 +534,6 @@ class Tester(unittest.TestCase):
...
@@ -610,68 +534,6 @@ class Tester(unittest.TestCase):
res2
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
BILINEAR
)
res2
=
F
.
rotate
(
tensor
,
45
,
interpolation
=
BILINEAR
)
assert_equal
(
res1
,
res2
)
assert_equal
(
res1
,
res2
)
def
test_gaussian_blur
(
self
):
small_image_tensor
=
torch
.
from_numpy
(
np
.
arange
(
3
*
10
*
12
,
dtype
=
"uint8"
).
reshape
((
10
,
12
,
3
))
).
permute
(
2
,
0
,
1
).
to
(
self
.
device
)
large_image_tensor
=
torch
.
from_numpy
(
np
.
arange
(
26
*
28
,
dtype
=
"uint8"
).
reshape
((
1
,
26
,
28
))
).
to
(
self
.
device
)
scripted_transform
=
torch
.
jit
.
script
(
F
.
gaussian_blur
)
# true_cv2_results = {
# # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
# "3_3_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
# "3_3_0.5": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
# "3_5_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
# "3_5_0.5": ...
# # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
# # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
# "23_23_1.7": ...
# }
p
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'assets'
,
'gaussian_blur_opencv_results.pt'
)
true_cv2_results
=
torch
.
load
(
p
)
for
tensor
in
[
small_image_tensor
,
large_image_tensor
]:
for
dt
in
[
None
,
torch
.
float32
,
torch
.
float64
,
torch
.
float16
]:
if
dt
==
torch
.
float16
and
torch
.
device
(
self
.
device
).
type
==
"cpu"
:
# skip float16 on CPU case
continue
if
dt
is
not
None
:
tensor
=
tensor
.
to
(
dtype
=
dt
)
for
ksize
in
[(
3
,
3
),
[
3
,
5
],
(
23
,
23
)]:
for
sigma
in
[[
0.5
,
0.5
],
(
0.5
,
0.5
),
(
0.8
,
0.8
),
(
1.7
,
1.7
)]:
_ksize
=
(
ksize
,
ksize
)
if
isinstance
(
ksize
,
int
)
else
ksize
_sigma
=
sigma
[
0
]
if
sigma
is
not
None
else
None
shape
=
tensor
.
shape
gt_key
=
"{}_{}_{}__{}_{}_{}"
.
format
(
shape
[
-
2
],
shape
[
-
1
],
shape
[
-
3
],
_ksize
[
0
],
_ksize
[
1
],
_sigma
)
if
gt_key
not
in
true_cv2_results
:
continue
true_out
=
torch
.
tensor
(
true_cv2_results
[
gt_key
]
).
reshape
(
shape
[
-
2
],
shape
[
-
1
],
shape
[
-
3
]).
permute
(
2
,
0
,
1
).
to
(
tensor
)
for
fn
in
[
F
.
gaussian_blur
,
scripted_transform
]:
out
=
fn
(
tensor
,
kernel_size
=
ksize
,
sigma
=
sigma
)
torch
.
testing
.
assert_close
(
out
,
true_out
,
rtol
=
0.0
,
atol
=
1.0
,
check_stride
=
False
,
msg
=
"{}, {}"
.
format
(
ksize
,
sigma
)
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Skip if no CUDA device"
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Skip if no CUDA device"
)
class
CUDATester
(
Tester
):
class
CUDATester
(
Tester
):
...
@@ -1141,5 +1003,146 @@ def test_adjust_gamma(device, dtype, config):
...
@@ -1141,5 +1003,146 @@ def test_adjust_gamma(device, dtype, config):
)
)
@
pytest
.
mark
.
parametrize
(
'device'
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
'func, args'
,
[
(
F_t
.
_get_image_size
,
()),
(
F_t
.
vflip
,
()),
(
F_t
.
hflip
,
()),
(
F_t
.
crop
,
(
1
,
2
,
4
,
5
)),
(
F_t
.
adjust_brightness
,
(
0.
,
)),
(
F_t
.
adjust_contrast
,
(
1.
,
)),
(
F_t
.
adjust_hue
,
(
-
0.5
,
)),
(
F_t
.
adjust_saturation
,
(
2.
,
)),
(
F_t
.
center_crop
,
([
10
,
11
],
)),
(
F_t
.
five_crop
,
([
10
,
11
],
)),
(
F_t
.
ten_crop
,
([
10
,
11
],
)),
(
F_t
.
pad
,
([
2
,
],
2
,
"constant"
)),
(
F_t
.
resize
,
([
10
,
11
],
)),
(
F_t
.
perspective
,
([
0.2
,
])),
(
F_t
.
gaussian_blur
,
((
2
,
2
),
(
0.7
,
0.5
))),
(
F_t
.
invert
,
()),
(
F_t
.
posterize
,
(
0
,
)),
(
F_t
.
solarize
,
(
0.3
,
)),
(
F_t
.
adjust_sharpness
,
(
0.3
,
)),
(
F_t
.
autocontrast
,
()),
(
F_t
.
equalize
,
())
])
def
test_assert_image_tensor
(
device
,
func
,
args
):
shape
=
(
100
,)
tensor
=
torch
.
rand
(
*
shape
,
dtype
=
torch
.
float
,
device
=
device
)
with
pytest
.
raises
(
Exception
,
match
=
r
"Tensor is not a torch image."
):
func
(
tensor
,
*
args
)
@
pytest
.
mark
.
parametrize
(
'device'
,
cpu_and_gpu
())
def
test_vflip
(
device
):
script_vflip
=
torch
.
jit
.
script
(
F
.
vflip
)
img_tensor
,
pil_img
=
_create_data
(
16
,
18
,
device
=
device
)
vflipped_img
=
F
.
vflip
(
img_tensor
)
vflipped_pil_img
=
F
.
vflip
(
pil_img
)
_assert_equal_tensor_to_pil
(
vflipped_img
,
vflipped_pil_img
)
# scriptable function test
vflipped_img_script
=
script_vflip
(
img_tensor
)
assert_equal
(
vflipped_img
,
vflipped_img_script
)
batch_tensors
=
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
device
)
_test_fn_on_batch
(
batch_tensors
,
F
.
vflip
)
@
pytest
.
mark
.
parametrize
(
'device'
,
cpu_and_gpu
())
def
test_hflip
(
device
):
script_hflip
=
torch
.
jit
.
script
(
F
.
hflip
)
img_tensor
,
pil_img
=
_create_data
(
16
,
18
,
device
=
device
)
hflipped_img
=
F
.
hflip
(
img_tensor
)
hflipped_pil_img
=
F
.
hflip
(
pil_img
)
_assert_equal_tensor_to_pil
(
hflipped_img
,
hflipped_pil_img
)
# scriptable function test
hflipped_img_script
=
script_hflip
(
img_tensor
)
assert_equal
(
hflipped_img
,
hflipped_img_script
)
batch_tensors
=
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
device
)
_test_fn_on_batch
(
batch_tensors
,
F
.
hflip
)
@
pytest
.
mark
.
parametrize
(
'device'
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
'top, left, height, width'
,
[
(
1
,
2
,
4
,
5
),
# crop inside top-left corner
(
2
,
12
,
3
,
4
),
# crop inside top-right corner
(
8
,
3
,
5
,
6
),
# crop inside bottom-left corner
(
8
,
11
,
4
,
3
),
# crop inside bottom-right corner
])
def
test_crop
(
device
,
top
,
left
,
height
,
width
):
script_crop
=
torch
.
jit
.
script
(
F
.
crop
)
img_tensor
,
pil_img
=
_create_data
(
16
,
18
,
device
=
device
)
pil_img_cropped
=
F
.
crop
(
pil_img
,
top
,
left
,
height
,
width
)
img_tensor_cropped
=
F
.
crop
(
img_tensor
,
top
,
left
,
height
,
width
)
_assert_equal_tensor_to_pil
(
img_tensor_cropped
,
pil_img_cropped
)
img_tensor_cropped
=
script_crop
(
img_tensor
,
top
,
left
,
height
,
width
)
_assert_equal_tensor_to_pil
(
img_tensor_cropped
,
pil_img_cropped
)
batch_tensors
=
_create_data_batch
(
16
,
18
,
num_samples
=
4
,
device
=
device
)
_test_fn_on_batch
(
batch_tensors
,
F
.
crop
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
)
@
pytest
.
mark
.
parametrize
(
'device'
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
'image_size'
,
(
'small'
,
'large'
))
@
pytest
.
mark
.
parametrize
(
'dt'
,
[
None
,
torch
.
float32
,
torch
.
float64
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
'ksize'
,
[(
3
,
3
),
[
3
,
5
],
(
23
,
23
)])
@
pytest
.
mark
.
parametrize
(
'sigma'
,
[[
0.5
,
0.5
],
(
0.5
,
0.5
),
(
0.8
,
0.8
),
(
1.7
,
1.7
)])
@
pytest
.
mark
.
parametrize
(
'fn'
,
[
F
.
gaussian_blur
,
torch
.
jit
.
script
(
F
.
gaussian_blur
)])
def
test_gaussian_blur
(
device
,
image_size
,
dt
,
ksize
,
sigma
,
fn
):
# true_cv2_results = {
# # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
# "3_3_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
# "3_3_0.5": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
# "3_5_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
# "3_5_0.5": ...
# # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
# # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
# "23_23_1.7": ...
# }
p
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'assets'
,
'gaussian_blur_opencv_results.pt'
)
true_cv2_results
=
torch
.
load
(
p
)
if
image_size
==
'small'
:
tensor
=
torch
.
from_numpy
(
np
.
arange
(
3
*
10
*
12
,
dtype
=
"uint8"
).
reshape
((
10
,
12
,
3
))
).
permute
(
2
,
0
,
1
).
to
(
device
)
else
:
tensor
=
torch
.
from_numpy
(
np
.
arange
(
26
*
28
,
dtype
=
"uint8"
).
reshape
((
1
,
26
,
28
))
).
to
(
device
)
if
dt
==
torch
.
float16
and
device
==
"cpu"
:
# skip float16 on CPU case
return
if
dt
is
not
None
:
tensor
=
tensor
.
to
(
dtype
=
dt
)
_ksize
=
(
ksize
,
ksize
)
if
isinstance
(
ksize
,
int
)
else
ksize
_sigma
=
sigma
[
0
]
if
sigma
is
not
None
else
None
shape
=
tensor
.
shape
gt_key
=
"{}_{}_{}__{}_{}_{}"
.
format
(
shape
[
-
2
],
shape
[
-
1
],
shape
[
-
3
],
_ksize
[
0
],
_ksize
[
1
],
_sigma
)
if
gt_key
not
in
true_cv2_results
:
return
true_out
=
torch
.
tensor
(
true_cv2_results
[
gt_key
]
).
reshape
(
shape
[
-
2
],
shape
[
-
1
],
shape
[
-
3
]).
permute
(
2
,
0
,
1
).
to
(
tensor
)
out
=
fn
(
tensor
,
kernel_size
=
ksize
,
sigma
=
sigma
)
torch
.
testing
.
assert_close
(
out
,
true_out
,
rtol
=
0.0
,
atol
=
1.0
,
check_stride
=
False
,
msg
=
"{}, {}"
.
format
(
ksize
,
sigma
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment