Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
86f551d3
Unverified
Commit
86f551d3
authored
Feb 09, 2023
by
Philip Meier
Committed by
GitHub
Feb 09, 2023
Browse files
update usages of torch.testing internals (#7203)
parent
5ea8e013
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
9 deletions
+13
-9
test/prototype_common_utils.py
test/prototype_common_utils.py
+6
-4
test/test_prototype_datasets_builtin.py
test/test_prototype_datasets_builtin.py
+7
-5
No files found.
test/prototype_common_utils.py
View file @
86f551d3
...
@@ -14,7 +14,7 @@ import torch
...
@@ -14,7 +14,7 @@ import torch
import
torch.testing
import
torch.testing
from
datasets_utils
import
combinations_grid
from
datasets_utils
import
combinations_grid
from
torch.nn.functional
import
one_hot
from
torch.nn.functional
import
one_hot
from
torch.testing._comparison
import
assert_equal
as
_assert_equal
,
BooleanPair
,
NonePair
,
NumberPair
,
TensorLikePair
from
torch.testing._comparison
import
BooleanPair
,
NonePair
,
not_close_error_metas
,
NumberPair
,
TensorLikePair
from
torchvision.prototype
import
datapoints
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms.functional
import
convert_dtype_image_tensor
,
to_image_tensor
from
torchvision.prototype.transforms.functional
import
convert_dtype_image_tensor
,
to_image_tensor
from
torchvision.transforms.functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms.functional_tensor
import
_max_value
as
get_max_value
...
@@ -73,7 +73,7 @@ class ImagePair(TensorLikePair):
...
@@ -73,7 +73,7 @@ class ImagePair(TensorLikePair):
actual
,
expected
=
self
.
_promote_for_comparison
(
actual
,
expected
)
actual
,
expected
=
self
.
_promote_for_comparison
(
actual
,
expected
)
mae
=
float
(
torch
.
abs
(
actual
-
expected
).
float
().
mean
())
mae
=
float
(
torch
.
abs
(
actual
-
expected
).
float
().
mean
())
if
mae
>
self
.
atol
:
if
mae
>
self
.
atol
:
raise
self
.
_make_error_meta
(
self
.
_fail
(
AssertionError
,
AssertionError
,
f
"The MAE of the images is
{
mae
}
, but only
{
self
.
atol
}
is allowed."
,
f
"The MAE of the images is
{
mae
}
, but only
{
self
.
atol
}
is allowed."
,
)
)
...
@@ -99,7 +99,7 @@ def assert_close(
...
@@ -99,7 +99,7 @@ def assert_close(
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
__tracebackhide__
=
True
__tracebackhide__
=
True
_assert_equal
(
error_metas
=
not_close_error_metas
(
actual
,
actual
,
expected
,
expected
,
pair_types
=
(
pair_types
=
(
...
@@ -117,10 +117,12 @@ def assert_close(
...
@@ -117,10 +117,12 @@ def assert_close(
check_dtype
=
check_dtype
,
check_dtype
=
check_dtype
,
check_layout
=
check_layout
,
check_layout
=
check_layout
,
check_stride
=
check_stride
,
check_stride
=
check_stride
,
msg
=
msg
,
**
kwargs
,
**
kwargs
,
)
)
if
error_metas
:
raise
error_metas
[
0
].
to_error
(
msg
)
assert_equal
=
functools
.
partial
(
assert_close
,
rtol
=
0
,
atol
=
0
)
assert_equal
=
functools
.
partial
(
assert_close
,
rtol
=
0
,
atol
=
0
)
...
...
test/test_prototype_datasets_builtin.py
View file @
86f551d3
import
functools
import
io
import
io
import
pickle
import
pickle
from
collections
import
deque
from
collections
import
deque
...
@@ -9,7 +8,7 @@ import torch
...
@@ -9,7 +8,7 @@ import torch
import
torchvision.prototype.transforms.utils
import
torchvision.prototype.transforms.utils
from
builtin_dataset_mocks
import
DATASET_MOCKS
,
parametrize_dataset_mocks
from
builtin_dataset_mocks
import
DATASET_MOCKS
,
parametrize_dataset_mocks
from
torch.testing._comparison
import
assert_equal
,
ObjectPair
,
TensorLikePair
from
torch.testing._comparison
import
not_close_error_metas
,
ObjectPair
,
TensorLikePair
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
...
@@ -25,9 +24,12 @@ from torchvision.prototype import datapoints, datasets, transforms
...
@@ -25,9 +24,12 @@ from torchvision.prototype import datapoints, datasets, transforms
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
assert_samples_equal
=
functools
.
partial
(
def
assert_samples_equal
(
*
args
,
msg
=
None
,
**
kwargs
):
assert_equal
,
pair_types
=
(
TensorLikePair
,
ObjectPair
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
error_metas
=
not_close_error_metas
(
)
*
args
,
pair_types
=
(
TensorLikePair
,
ObjectPair
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
,
**
kwargs
)
if
error_metas
:
raise
error_metas
[
0
].
to_error
(
msg
)
def
extract_datapipes
(
dp
):
def
extract_datapipes
(
dp
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment