OpenDAS / vision · Commit a629a9b2 (unverified)

Port some tests to pytest in test_functional_tensor.py (#3988)

Authored Jun 07, 2021 by Vivek Kumar; committed by GitHub on Jun 07, 2021
Parent: 7fb4ef57

Showing 1 changed file with 249 additions and 153 deletions

test/test_functional_tensor.py (+249, -153)
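The pattern applied throughout the diff below is to move each test out of the unittest.TestCase class and turn it into a free function parametrized over devices with pytest. A minimal, self-contained sketch of that pattern follows; the cpu_and_gpu() helper here is a simplified stand-in for the one in the repository's test utilities, and test_add is a hypothetical example, not part of this commit:

    import unittest

    import pytest
    import torch


    def cpu_and_gpu():
        # Simplified stand-in: the real helper lives in the repo's common test utilities.
        return ("cpu",) + (("cuda",) if torch.cuda.is_available() else ())


    # Before: the device is fixed in setUp and every test reads self.device.
    class Tester(unittest.TestCase):
        def setUp(self):
            self.device = "cpu"

        def test_add(self):
            t = torch.ones(3, device=self.device)
            self.assertTrue(torch.equal(t + 1, torch.full((3,), 2.0, device=self.device)))


    # After: a free function that pytest runs once per device returned by cpu_and_gpu().
    @pytest.mark.parametrize('device', cpu_and_gpu())
    def test_add(device):
        t = torch.ones(3, device=device)
        assert torch.equal(t + 1, torch.full((3,), 2.0, device=device))

Parametrizing on device also replaces the unittest assertion helpers with plain assert statements, which pytest rewrites to show the compared values on failure.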
@@ -36,160 +36,89 @@ class Tester(unittest.TestCase):

    def setUp(self):
        self.device = "cpu"
    def test_hsv2rgb(self):
        scripted_fn = torch.jit.script(F_t._hsv2rgb)
        shape = (3, 100, 150)
        for _ in range(10):
            hsv_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            rgb_img = F_t._hsv2rgb(hsv_img)
            ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

            h, s, v, = hsv_img.unbind(0)
            h = h.flatten().cpu().numpy()
            s = s.flatten().cpu().numpy()
            v = v.flatten().cpu().numpy()

            rgb = []
            for h1, s1, v1 in zip(h, s, v):
                rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
            colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
            torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)

            s_rgb_img = scripted_fn(hsv_img)
            torch.testing.assert_close(rgb_img, s_rgb_img)

        batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
        _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
    def test_rgb2hsv(self):
        scripted_fn = torch.jit.script(F_t._rgb2hsv)
        shape = (3, 150, 100)
        for _ in range(10):
            rgb_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            hsv_img = F_t._rgb2hsv(rgb_img)
            ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

            r, g, b, = rgb_img.unbind(dim=-3)
            r = r.flatten().cpu().numpy()
            g = g.flatten().cpu().numpy()
            b = b.flatten().cpu().numpy()

            hsv = []
            for r1, g1, b1 in zip(r, g, b):
                hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

            colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=self.device)

            ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
            colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)

            max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
            max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
            max_diff = max(max_diff_h, max_diff_sv)
            self.assertLess(max_diff, 1e-5)

            s_hsv_img = scripted_fn(rgb_img)
            torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)

        batch_tensors = _create_data_batch(120, 100, num_samples=4, device=self.device).float()
        _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

        img_tensor, pil_img = _create_data(32, 34, device=self.device)

        for num_output_channels in (3, 1):
            gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
            gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

            _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

            s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
            assert_equal(s_gray_tensor, gray_tensor)

            batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
            _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
    def test_center_crop(self):
        script_center_crop = torch.jit.script(F.center_crop)

        img_tensor, pil_img = _create_data(32, 34, device=self.device)

        cropped_pil_image = F.center_crop(pil_img, [10, 11])

        cropped_tensor = F.center_crop(img_tensor, [10, 11])
        _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

        cropped_tensor = script_center_crop(img_tensor, [10, 11])
        _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

        batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
        _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
    def test_five_crop(self):
        script_five_crop = torch.jit.script(F.five_crop)

        img_tensor, pil_img = _create_data(32, 34, device=self.device)

        cropped_pil_images = F.five_crop(pil_img, [10, 11])

        cropped_tensors = F.five_crop(img_tensor, [10, 11])
        for i in range(5):
            _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_five_crop(img_tensor, [10, 11])
        for i in range(5):
            _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

        batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                assert_equal(true_transformed_img, transformed_img)

        # scriptable function test
        s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            assert_equal(transformed_batch, s_transformed_batch)
    def test_ten_crop(self):
        script_ten_crop = torch.jit.script(F.ten_crop)

        img_tensor, pil_img = _create_data(32, 34, device=self.device)

        cropped_pil_images = F.ten_crop(pil_img, [10, 11])

        cropped_tensors = F.ten_crop(img_tensor, [10, 11])
        for i in range(10):
            _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_ten_crop(img_tensor, [10, 11])
        for i in range(10):
            _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

        batch_tensors = _create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                assert_equal(true_transformed_img, transformed_img)

        # scriptable function test
        s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            assert_equal(transformed_batch, s_transformed_batch)
    def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
        img_size = pil_img.size
        dt = tensor.dtype
        for r in [NEAREST, ]:
            for a in range(-180, 180, 17):
                for e in [True, False]:
                    for c in centers:
                        for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
                            f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                            out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                            for fn in [F.rotate, scripted_rotate]:
                                out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()

                                if out_tensor.dtype != torch.uint8:
                                    out_tensor = out_tensor.to(torch.uint8)

                                self.assertEqual(
                                    out_tensor.shape, out_pil_tensor.shape,
                                    msg="{}: {} vs {}".format(
                                        (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
                                    )
                                )

                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                                # Tolerance : less than 3% of different pixels
                                self.assertLess(
                                    ratio_diff_pixels, 0.03,
                                    msg="{}: {}\n{} vs \n{}".format(
                                        (img_size, r, dt, a, e, c, f),
                                        ratio_diff_pixels,
                                        out_tensor[0, :7, :7],
                                        out_pil_tensor[0, :7, :7]
                                    )
                                )
    def test_rotate(self):
        # Tests on square image
        scripted_rotate = torch.jit.script(F.rotate)

        data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)]
        for tensor, pil_img in data:
            img_size = pil_img.size
            centers = [
                None,
                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
            ]

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)

                batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                center = (20, 22)
                _test_fn_on_batch(
                    batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
                )

        tensor, pil_img = data[0]
        # assert deprecation warning and non-BC
        with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
            res1 = F.rotate(tensor, 45, resample=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            assert_equal(res1, res2)

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.rotate(tensor, 45, interpolation=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            assert_equal(res1, res2)

    @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
@@ -1174,5 +1103,172 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
    )
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_hsv2rgb(device):
    scripted_fn = torch.jit.script(F_t._hsv2rgb)
    shape = (3, 100, 150)
    for _ in range(10):
        hsv_img = torch.rand(*shape, dtype=torch.float, device=device)
        rgb_img = F_t._hsv2rgb(hsv_img)
        ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

        h, s, v, = hsv_img.unbind(0)
        h = h.flatten().cpu().numpy()
        s = s.flatten().cpu().numpy()
        v = v.flatten().cpu().numpy()

        rgb = []
        for h1, s1, v1 in zip(h, s, v):
            rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
        colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=device)
        torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)

        s_rgb_img = scripted_fn(hsv_img)
        torch.testing.assert_close(rgb_img, s_rgb_img)

    batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_rgb2hsv(device):
    scripted_fn = torch.jit.script(F_t._rgb2hsv)
    shape = (3, 150, 100)
    for _ in range(10):
        rgb_img = torch.rand(*shape, dtype=torch.float, device=device)
        hsv_img = F_t._rgb2hsv(rgb_img)
        ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

        r, g, b, = rgb_img.unbind(dim=-3)
        r = r.flatten().cpu().numpy()
        g = g.flatten().cpu().numpy()
        b = b.flatten().cpu().numpy()

        hsv = []
        for r1, g1, b1 in zip(r, g, b):
            hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

        colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=device)

        ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
        colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)

        max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
        max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
        max_diff = max(max_diff_h, max_diff_sv)
        assert max_diff < 1e-5

        s_hsv_img = scripted_fn(rgb_img)
        torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)

    batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
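A note on the hue comparison in test_rgb2hsv (both in the removed unittest version and in this ported one): hue is a circular quantity, so values just below 1.0 and just above 0.0 denote nearly the same color, and a plain difference would flag them as far apart. Comparing sin(2*pi*h) lets wrapped-around hues compare as close. A tiny illustrative sketch, independent of the test helpers, with made-up example values:

    import math

    import torch

    h_ref = torch.tensor([0.999])  # e.g. hue produced by colorsys
    h_out = torch.tensor([0.001])  # same color, wrapped past 1.0 back to 0.0

    plain_diff = (h_ref - h_out).abs().max()  # ~0.998: looks like a large error
    wrapped_diff = ((h_ref * 2 * math.pi).sin() - (h_out * 2 * math.pi).sin()).abs().max()  # ~0.0126: small

    print(plain_diff.item(), wrapped_diff.item())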
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('num_output_channels', (3, 1))
def test_rgb_to_grayscale(device, num_output_channels):
    script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
    gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

    _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

    s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
    assert_equal(s_gray_tensor, gray_tensor)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_center_crop(device):
    script_center_crop = torch.jit.script(F.center_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    cropped_pil_image = F.center_crop(pil_img, [10, 11])

    cropped_tensor = F.center_crop(img_tensor, [10, 11])
    _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

    cropped_tensor = script_center_crop(img_tensor, [10, 11])
    _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_five_crop(device):
    script_five_crop = torch.jit.script(F.five_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    cropped_pil_images = F.five_crop(pil_img, [10, 11])

    cropped_tensors = F.five_crop(img_tensor, [10, 11])
    for i in range(5):
        _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

    cropped_tensors = script_five_crop(img_tensor, [10, 11])
    for i in range(5):
        _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
        assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)

        for j in range(len(tuple_transformed_imgs)):
            true_transformed_img = tuple_transformed_imgs[j]
            transformed_img = tuple_transformed_batches[j][i, ...]
            assert_equal(true_transformed_img, transformed_img)

    # scriptable function test
    s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
    for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
        assert_equal(transformed_batch, s_transformed_batch)
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_ten_crop(device):
    script_ten_crop = torch.jit.script(F.ten_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    cropped_pil_images = F.ten_crop(pil_img, [10, 11])

    cropped_tensors = F.ten_crop(img_tensor, [10, 11])
    for i in range(10):
        _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

    cropped_tensors = script_ten_crop(img_tensor, [10, 11])
    for i in range(10):
        _assert_equal_tensor_to_pil(cropped_tensors[i], cropped_pil_images[i])

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
        assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)

        for j in range(len(tuple_transformed_imgs)):
            true_transformed_img = tuple_transformed_imgs[j]
            transformed_img = tuple_transformed_batches[j][i, ...]
            assert_equal(true_transformed_img, transformed_img)

    # scriptable function test
    s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
    for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
        assert_equal(transformed_batch, s_transformed_batch)
if __name__ == '__main__':
    unittest.main()
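After this change the ported tests can be selected and run directly with pytest. A hypothetical local invocation, assuming a repository checkout with pytest installed (the -k expression below is only an example filter):

    import pytest

    # Run just the color-conversion tests ported in this commit, verbosely.
    pytest.main(["test/test_functional_tensor.py", "-k", "hsv2rgb or rgb2hsv", "-v"])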