Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
97b05a89
Unverified
Commit
97b05a89
authored
May 21, 2021
by
Nicolas Hug
Committed by
GitHub
May 21, 2021
Browse files
Use torch.testing.assert_close in test_transforms_tensor.py (#3885)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
55150bfb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
11 deletions
+18
-11
test/test_transforms_tensor.py
test/test_transforms_tensor.py
+18
-11
No files found.
test/test_transforms_tensor.py
View file @
97b05a89
...
@@ -10,6 +10,7 @@ import unittest
...
@@ -10,6 +10,7 @@ import unittest
from
typing
import
Sequence
from
typing
import
Sequence
from
common_utils
import
TransformsTester
,
get_tmp_dir
,
int_dtypes
,
float_dtypes
from
common_utils
import
TransformsTester
,
get_tmp_dir
,
int_dtypes
,
float_dtypes
from
_assert_utils
import
assert_equal
NEAREST
,
BILINEAR
,
BICUBIC
=
InterpolationMode
.
NEAREST
,
InterpolationMode
.
BILINEAR
,
InterpolationMode
.
BICUBIC
NEAREST
,
BILINEAR
,
BICUBIC
=
InterpolationMode
.
NEAREST
,
InterpolationMode
.
BILINEAR
,
InterpolationMode
.
BICUBIC
...
@@ -38,7 +39,7 @@ class Tester(TransformsTester):
...
@@ -38,7 +39,7 @@ class Tester(TransformsTester):
out1
=
transform
(
tensor
)
out1
=
transform
(
tensor
)
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
out2
=
s_transform
(
tensor
)
out2
=
s_transform
(
tensor
)
self
.
assert
True
(
out1
.
equal
(
out2
)
,
msg
=
msg
)
assert
_equal
(
out1
,
out2
,
msg
=
msg
)
def
_test_transform_vs_scripted_on_batch
(
self
,
transform
,
s_transform
,
batch_tensors
,
msg
=
None
):
def
_test_transform_vs_scripted_on_batch
(
self
,
transform
,
s_transform
,
batch_tensors
,
msg
=
None
):
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
...
@@ -48,11 +49,11 @@ class Tester(TransformsTester):
...
@@ -48,11 +49,11 @@ class Tester(TransformsTester):
img_tensor
=
batch_tensors
[
i
,
...]
img_tensor
=
batch_tensors
[
i
,
...]
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
transformed_img
=
transform
(
img_tensor
)
transformed_img
=
transform
(
img_tensor
)
self
.
assert
True
(
transformed_img
.
equal
(
transformed_batch
[
i
,
...]
)
,
msg
=
msg
)
assert
_equal
(
transformed_img
,
transformed_batch
[
i
,
...],
msg
=
msg
)
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
s_transformed_batch
=
s_transform
(
batch_tensors
)
s_transformed_batch
=
s_transform
(
batch_tensors
)
self
.
assert
True
(
transformed_batch
.
equal
(
s_transformed_batch
)
,
msg
=
msg
)
assert
_equal
(
transformed_batch
,
s_transformed_batch
,
msg
=
msg
)
def
_test_class_op
(
self
,
method
,
meth_kwargs
=
None
,
test_exact_match
=
True
,
**
match_kwargs
):
def
_test_class_op
(
self
,
method
,
meth_kwargs
=
None
,
test_exact_match
=
True
,
**
match_kwargs
):
if
meth_kwargs
is
None
:
if
meth_kwargs
is
None
:
...
@@ -75,7 +76,7 @@ class Tester(TransformsTester):
...
@@ -75,7 +76,7 @@ class Tester(TransformsTester):
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
transformed_tensor_script
=
scripted_fn
(
tensor
)
transformed_tensor_script
=
scripted_fn
(
tensor
)
self
.
assert
True
(
transformed_tensor
.
equal
(
transformed_tensor_script
)
)
assert
_equal
(
transformed_tensor
,
transformed_tensor_script
)
batch_tensors
=
self
.
_create_data_batch
(
height
=
23
,
width
=
34
,
channels
=
3
,
num_samples
=
4
,
device
=
self
.
device
)
batch_tensors
=
self
.
_create_data_batch
(
height
=
23
,
width
=
34
,
channels
=
3
,
num_samples
=
4
,
device
=
self
.
device
)
self
.
_test_transform_vs_scripted_on_batch
(
f
,
scripted_fn
,
batch_tensors
)
self
.
_test_transform_vs_scripted_on_batch
(
f
,
scripted_fn
,
batch_tensors
)
...
@@ -270,8 +271,11 @@ class Tester(TransformsTester):
...
@@ -270,8 +271,11 @@ class Tester(TransformsTester):
self
.
assertEqual
(
len
(
transformed_t_list
),
len
(
transformed_t_list_script
))
self
.
assertEqual
(
len
(
transformed_t_list
),
len
(
transformed_t_list_script
))
self
.
assertEqual
(
len
(
transformed_t_list_script
),
out_length
)
self
.
assertEqual
(
len
(
transformed_t_list_script
),
out_length
)
for
transformed_tensor
,
transformed_tensor_script
in
zip
(
transformed_t_list
,
transformed_t_list_script
):
for
transformed_tensor
,
transformed_tensor_script
in
zip
(
transformed_t_list
,
transformed_t_list_script
):
self
.
assertTrue
(
transformed_tensor
.
equal
(
transformed_tensor_script
),
assert_equal
(
msg
=
"{} vs {}"
.
format
(
transformed_tensor
,
transformed_tensor_script
))
transformed_tensor
,
transformed_tensor_script
,
msg
=
"{} vs {}"
.
format
(
transformed_tensor
,
transformed_tensor_script
),
)
# test for class interface
# test for class interface
fn
=
getattr
(
T
,
method
)(
**
meth_kwargs
)
fn
=
getattr
(
T
,
method
)(
**
meth_kwargs
)
...
@@ -289,8 +293,11 @@ class Tester(TransformsTester):
...
@@ -289,8 +293,11 @@ class Tester(TransformsTester):
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
transformed_img_list
=
fn
(
img_tensor
)
transformed_img_list
=
fn
(
img_tensor
)
for
transformed_img
,
transformed_batch
in
zip
(
transformed_img_list
,
transformed_batch_list
):
for
transformed_img
,
transformed_batch
in
zip
(
transformed_img_list
,
transformed_batch_list
):
self
.
assertTrue
(
transformed_img
.
equal
(
transformed_batch
[
i
,
...]),
assert_equal
(
msg
=
"{} vs {}"
.
format
(
transformed_img
,
transformed_batch
[
i
,
...]))
transformed_img
,
transformed_batch
[
i
,
...],
msg
=
"{} vs {}"
.
format
(
transformed_img
,
transformed_batch
[
i
,
...]),
)
with
get_tmp_dir
()
as
tmp_dir
:
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_op_list_{}.pt"
.
format
(
method
)))
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_op_list_{}.pt"
.
format
(
method
)))
...
@@ -505,7 +512,7 @@ class Tester(TransformsTester):
...
@@ -505,7 +512,7 @@ class Tester(TransformsTester):
transformed_batch
=
fn
(
batch_tensors
)
transformed_batch
=
fn
(
batch_tensors
)
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
s_transformed_batch
=
scripted_fn
(
batch_tensors
)
s_transformed_batch
=
scripted_fn
(
batch_tensors
)
self
.
assert
True
(
transformed_batch
.
equal
(
s_transformed_batch
)
)
assert
_equal
(
transformed_batch
,
s_transformed_batch
)
with
get_tmp_dir
()
as
tmp_dir
:
with
get_tmp_dir
()
as
tmp_dir
:
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_norm.pt"
))
scripted_fn
.
save
(
os
.
path
.
join
(
tmp_dir
,
"t_norm.pt"
))
...
@@ -525,7 +532,7 @@ class Tester(TransformsTester):
...
@@ -525,7 +532,7 @@ class Tester(TransformsTester):
transformed_tensor
=
transforms
(
tensor
)
transformed_tensor
=
transforms
(
tensor
)
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
transformed_tensor_script
=
scripted_fn
(
tensor
)
transformed_tensor_script
=
scripted_fn
(
tensor
)
self
.
assert
True
(
transformed_tensor
.
equal
(
transformed_tensor_script
)
,
msg
=
"{}"
.
format
(
transforms
))
assert
_equal
(
transformed_tensor
,
transformed_tensor_script
,
msg
=
"{}"
.
format
(
transforms
))
t
=
T
.
Compose
([
t
=
T
.
Compose
([
lambda
x
:
x
,
lambda
x
:
x
,
...
@@ -551,7 +558,7 @@ class Tester(TransformsTester):
...
@@ -551,7 +558,7 @@ class Tester(TransformsTester):
transformed_tensor
=
transforms
(
tensor
)
transformed_tensor
=
transforms
(
tensor
)
torch
.
manual_seed
(
12
)
torch
.
manual_seed
(
12
)
transformed_tensor_script
=
scripted_fn
(
tensor
)
transformed_tensor_script
=
scripted_fn
(
tensor
)
self
.
assert
True
(
transformed_tensor
.
equal
(
transformed_tensor_script
)
,
msg
=
"{}"
.
format
(
transforms
))
assert
_equal
(
transformed_tensor
,
transformed_tensor_script
,
msg
=
"{}"
.
format
(
transforms
))
if
torch
.
device
(
self
.
device
).
type
==
"cpu"
:
if
torch
.
device
(
self
.
device
).
type
==
"cpu"
:
# Can't check this twice, otherwise
# Can't check this twice, otherwise
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment