OpenDAS / vision · Commits

Commit a629a9b2 (unverified)
Authored Jun 07, 2021 by Vivek Kumar; committed by GitHub on Jun 07, 2021

    Port some tests to pytest in test_functional_tensor.py (#3988)
Parent: 7fb4ef57

Changes: 1 changed file with 249 additions and 153 deletions.

test/test_functional_tensor.py  (+249 / -153)   [view file @ a629a9b2]
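For readers skimming the diff, the porting pattern applied throughout is the same: a device-bound unittest.TestCase method becomes a module-level function parametrized over devices with pytest.mark.parametrize, and unittest assertions become plain asserts. The sketch below illustrates that pattern only; the class name, the trivial double-flip test, and the stand-in cpu_and_gpu() helper are made up for the example and are not part of this commit.

# Illustrative sketch of the porting pattern used in this commit, not code from
# the diff itself.
import unittest

import pytest
import torch


def cpu_and_gpu():
    # Stand-in for the helper the real tests import from the suite's
    # common_utils: yield the devices available on this machine.
    return ('cpu',) + (('cuda',) if torch.cuda.is_available() else ())


class OldStyleTester(unittest.TestCase):
    def setUp(self):
        self.device = "cpu"

    def test_flip(self):
        img = torch.rand(3, 4, 4, device=self.device)
        self.assertTrue(torch.equal(img.flip(-1).flip(-1), img))


# After the port: no class and no setUp; the device is a pytest parameter and
# a plain assert replaces the unittest assert* method.
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_flip(device):
    img = torch.rand(3, 4, 4, device=device)
    assert torch.equal(img.flip(-1).flip(-1), img)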
@@ -36,160 +36,89 @@ class Tester(unittest.TestCase):

Unchanged context at the top of the hunk:

    def setUp(self):
        self.device = "cpu"

Removed side (unittest methods deleted from the Tester class; they reappear as pytest functions in the second hunk below):

    def test_hsv2rgb(self):
        scripted_fn = torch.jit.script(F_t._hsv2rgb)
        shape = (3, 100, 150)
        for _ in range(10):
            hsv_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            rgb_img = F_t._hsv2rgb(hsv_img)
            ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

            h, s, v, = hsv_img.unbind(0)
            h = h.flatten().cpu().numpy()
            s = s.flatten().cpu().numpy()
            v = v.flatten().cpu().numpy()

            rgb = []
            for h1, s1, v1 in zip(h, s, v):
                rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
            colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
            torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)

            s_rgb_img = scripted_fn(hsv_img)
            torch.testing.assert_close(rgb_img, s_rgb_img)

        batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
        _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)

    def test_rgb2hsv(self):
        scripted_fn = torch.jit.script(F_t._rgb2hsv)
        shape = (3, 150, 100)
        for _ in range(10):
            rgb_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            hsv_img = F_t._rgb2hsv(rgb_img)
            ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

            r, g, b, = rgb_img.unbind(dim=-3)
            r = r.flatten().cpu().numpy()
            g = g.flatten().cpu().numpy()
            b = b.flatten().cpu().numpy()

            hsv = []
            for r1, g1, b1 in zip(r, g, b):
                hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

            colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=self.device)

            ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
            colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)

            max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
            max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
            max_diff = max(max_diff_h, max_diff_sv)
            self.assertLess(max_diff, 1e-5)

            s_hsv_img = scripted_fn(rgb_img)
            torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)

        batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
        _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)

    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

        img_tensor, pil_img = _create_data(32, 34, device=self.device)

        for num_output_channels in (3, 1):
            gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
            gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

            _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

            s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
            assert_equal(s_gray_tensor, gray_tensor)

            batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
            _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)

    def test_center_crop(self):
        script_center_crop = torch.jit.script(F.center_crop)

        img_tensor, pil_img = _create_data(32, 34, device=self.device)

        cropped_pil_image = F.center_crop(pil_img, [10, 11])

        cropped_tensor = F.center_crop(img_tensor, [10, 11])
        _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

        cropped_tensor = script_center_crop(img_tensor, [10, 11])
        _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

        batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
        _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])

    def test_five_crop(self):
        script_five_crop = torch.jit.script(F.five_crop)

        img_tensor, pil_img = _create_data(32, 34, device=self.device)

        cropped_pil_images = F.five_crop(pil_img, [10, 11])

        cropped_tensors = F.five_crop(img_tensor, [10, 11])
        for i in range(5):
            _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_five_crop(img_tensor, [10, 11])
        for i in range(5):
            _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

        batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                assert_equal(true_transformed_img, transformed_img)

        # scriptable function test
        s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            assert_equal(transformed_batch, s_transformed_batch)

    def test_ten_crop(self):
        script_ten_crop = torch.jit.script(F.ten_crop)

        img_tensor, pil_img = _create_data(32, 34, device=self.device)

        cropped_pil_images = F.ten_crop(pil_img, [10, 11])

        cropped_tensors = F.ten_crop(img_tensor, [10, 11])
        for i in range(10):
            _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_ten_crop(img_tensor, [10, 11])
        for i in range(10):
            _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

        batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                assert_equal(true_transformed_img, transformed_img)

        # scriptable function test
        s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            assert_equal(transformed_batch, s_transformed_batch)

New side (the rotate tests that sit at this position in the Tester class after the removal):

    def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
        img_size = pil_img.size
        dt = tensor.dtype
        for r in [NEAREST, ]:
            for a in range(-180, 180, 17):
                for e in [True, False]:
                    for c in centers:
                        for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
                            f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                            out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                            for fn in [F.rotate, scripted_rotate]:
                                out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()

                                if out_tensor.dtype != torch.uint8:
                                    out_tensor = out_tensor.to(torch.uint8)

                                self.assertEqual(
                                    out_tensor.shape,
                                    out_pil_tensor.shape,
                                    msg="{}: {} vs {}".format(
                                        (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
                                    ))

                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                                # Tolerance : less than 3% of different pixels
                                self.assertLess(
                                    ratio_diff_pixels,
                                    0.03,
                                    msg="{}: {}\n{} vs \n{}".format(
                                        (img_size, r, dt, a, e, c, f),
                                        ratio_diff_pixels,
                                        out_tensor[0, :7, :7],
                                        out_pil_tensor[0, :7, :7]
                                    )
                                )

    def test_rotate(self):
        # Tests on square image
        scripted_rotate = torch.jit.script(F.rotate)

        data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
        for tensor, pil_img in data:

            img_size = pil_img.size
            centers = [
                None,
                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
            ]

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)

            batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
            if dt is not None:
                batch_tensors = batch_tensors.to(dtype=dt)

            center = (20, 22)
            _test_fn_on_batch(
                batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
            )

        tensor, pil_img = data[0]
        # assert deprecation warning and non-BC
        with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
            res1 = F.rotate(tensor, 45, resample=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            assert_equal(res1, res2)

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.rotate(tensor, 45, interpolation=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            assert_equal(res1, res2)

Unchanged context at the bottom of the hunk (decorator of the next test, shown on both sides of the view):

    @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
@@ -1174,5 +1103,172 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):

The context at the top of the hunk is the closing parenthesis of the preceding test_gaussian_blur call. The rest of the hunk adds module-level pytest functions (the ported versions of the methods removed above), inserted before the file's __main__ block:

@pytest.mark.parametrize('device', cpu_and_gpu())
def test_hsv2rgb(device):
    scripted_fn = torch.jit.script(F_t._hsv2rgb)
    shape = (3, 100, 150)
    for _ in range(10):
        hsv_img = torch.rand(*shape, dtype=torch.float, device=device)
        rgb_img = F_t._hsv2rgb(hsv_img)
        ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

        h, s, v, = hsv_img.unbind(0)
        h = h.flatten().cpu().numpy()
        s = s.flatten().cpu().numpy()
        v = v.flatten().cpu().numpy()

        rgb = []
        for h1, s1, v1 in zip(h, s, v):
            rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
        colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=device)
        torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)

        s_rgb_img = scripted_fn(hsv_img)
        torch.testing.assert_close(rgb_img, s_rgb_img)

    batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_rgb2hsv(device):
    scripted_fn = torch.jit.script(F_t._rgb2hsv)
    shape = (3, 150, 100)
    for _ in range(10):
        rgb_img = torch.rand(*shape, dtype=torch.float, device=device)
        hsv_img = F_t._rgb2hsv(rgb_img)
        ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

        r, g, b, = rgb_img.unbind(dim=-3)
        r = r.flatten().cpu().numpy()
        g = g.flatten().cpu().numpy()
        b = b.flatten().cpu().numpy()

        hsv = []
        for r1, g1, b1 in zip(r, g, b):
            hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

        colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=device)

        ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
        colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)

        max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
        max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
        max_diff = max(max_diff_h, max_diff_sv)
        assert max_diff < 1e-5

        s_hsv_img = scripted_fn(rgb_img)
        torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)

    batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('num_output_channels', (3, 1))
def test_rgb_to_grayscale(device, num_output_channels):
    script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
    gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

    _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

    s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
    assert_equal(s_gray_tensor, gray_tensor)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_center_crop(device):
    script_center_crop = torch.jit.script(F.center_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    cropped_pil_image = F.center_crop(pil_img, [10, 11])

    cropped_tensor = F.center_crop(img_tensor, [10, 11])
    _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

    cropped_tensor = script_center_crop(img_tensor, [10, 11])
    _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_five_crop(device):
    script_five_crop = torch.jit.script(F.five_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    cropped_pil_images = F.five_crop(pil_img, [10, 11])

    cropped_tensors = F.five_crop(img_tensor, [10, 11])
    for i in range(5):
        _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

    cropped_tensors = script_five_crop(img_tensor, [10, 11])
    for i in range(5):
        _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
        assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)

        for j in range(len(tuple_transformed_imgs)):
            true_transformed_img = tuple_transformed_imgs[j]
            transformed_img = tuple_transformed_batches[j][i, ...]
            assert_equal(true_transformed_img, transformed_img)

    # scriptable function test
    s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
    for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
        assert_equal(transformed_batch, s_transformed_batch)


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_ten_crop(device):
    script_ten_crop = torch.jit.script(F.ten_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    cropped_pil_images = F.ten_crop(pil_img, [10, 11])

    cropped_tensors = F.ten_crop(img_tensor, [10, 11])
    for i in range(10):
        _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

    cropped_tensors = script_ten_crop(img_tensor, [10, 11])
    for i in range(10):
        _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
        assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)

        for j in range(len(tuple_transformed_imgs)):
            true_transformed_img = tuple_transformed_imgs[j]
            transformed_img = tuple_transformed_batches[j][i, ...]
            assert_equal(true_transformed_img, transformed_img)

    # scriptable function test
    s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
    for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
        assert_equal(transformed_batch, s_transformed_batch)


Unchanged context at the end of the hunk (shown on both sides of the view):

if __name__ == '__main__':
    unittest.main()
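Note that the file still ends with unittest.main(), which only runs the remaining TestCase-based tests; the functions added above are collected by pytest instead. A hypothetical way to run just the six ported tests from Python, assuming pytest is installed and the script is executed from the repository's test/ directory (illustrative, not part of the commit):

# Illustrative only: select the ported tests with pytest's -k expression.
import sys

import pytest

ported = (
    "test_hsv2rgb or test_rgb2hsv or test_rgb_to_grayscale or "
    "test_center_crop or test_five_crop or test_ten_crop"
)
sys.exit(pytest.main(["test_functional_tensor.py", "-k", ported, "-v"]))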