Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
86f551d3
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e25e525fde9bf0cca585c2b610a078b284f5bc87"
Unverified
Commit
86f551d3
authored
Feb 09, 2023
by
Philip Meier
Committed by
GitHub
Feb 09, 2023
Browse files
update usages of torch.testing internals (#7203)
parent
5ea8e013
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
9 deletions
+13
-9
test/prototype_common_utils.py
test/prototype_common_utils.py
+6
-4
test/test_prototype_datasets_builtin.py
test/test_prototype_datasets_builtin.py
+7
-5
No files found.
test/prototype_common_utils.py
View file @
86f551d3
...
@@ -14,7 +14,7 @@ import torch
...
@@ -14,7 +14,7 @@ import torch
import
torch.testing
import
torch.testing
from
datasets_utils
import
combinations_grid
from
datasets_utils
import
combinations_grid
from
torch.nn.functional
import
one_hot
from
torch.nn.functional
import
one_hot
from
torch.testing._comparison
import
assert_equal
as
_assert_equal
,
BooleanPair
,
NonePair
,
NumberPair
,
TensorLikePair
from
torch.testing._comparison
import
BooleanPair
,
NonePair
,
not_close_error_metas
,
NumberPair
,
TensorLikePair
from
torchvision.prototype
import
datapoints
from
torchvision.prototype
import
datapoints
from
torchvision.prototype.transforms.functional
import
convert_dtype_image_tensor
,
to_image_tensor
from
torchvision.prototype.transforms.functional
import
convert_dtype_image_tensor
,
to_image_tensor
from
torchvision.transforms.functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms.functional_tensor
import
_max_value
as
get_max_value
...
@@ -73,7 +73,7 @@ class ImagePair(TensorLikePair):
...
@@ -73,7 +73,7 @@ class ImagePair(TensorLikePair):
actual
,
expected
=
self
.
_promote_for_comparison
(
actual
,
expected
)
actual
,
expected
=
self
.
_promote_for_comparison
(
actual
,
expected
)
mae
=
float
(
torch
.
abs
(
actual
-
expected
).
float
().
mean
())
mae
=
float
(
torch
.
abs
(
actual
-
expected
).
float
().
mean
())
if
mae
>
self
.
atol
:
if
mae
>
self
.
atol
:
raise
self
.
_make_error_meta
(
self
.
_fail
(
AssertionError
,
AssertionError
,
f
"The MAE of the images is
{
mae
}
, but only
{
self
.
atol
}
is allowed."
,
f
"The MAE of the images is
{
mae
}
, but only
{
self
.
atol
}
is allowed."
,
)
)
...
@@ -99,7 +99,7 @@ def assert_close(
...
@@ -99,7 +99,7 @@ def assert_close(
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
__tracebackhide__
=
True
__tracebackhide__
=
True
_assert_equal
(
error_metas
=
not_close_error_metas
(
actual
,
actual
,
expected
,
expected
,
pair_types
=
(
pair_types
=
(
...
@@ -117,10 +117,12 @@ def assert_close(
...
@@ -117,10 +117,12 @@ def assert_close(
check_dtype
=
check_dtype
,
check_dtype
=
check_dtype
,
check_layout
=
check_layout
,
check_layout
=
check_layout
,
check_stride
=
check_stride
,
check_stride
=
check_stride
,
msg
=
msg
,
**
kwargs
,
**
kwargs
,
)
)
if
error_metas
:
raise
error_metas
[
0
].
to_error
(
msg
)
assert_equal
=
functools
.
partial
(
assert_close
,
rtol
=
0
,
atol
=
0
)
assert_equal
=
functools
.
partial
(
assert_close
,
rtol
=
0
,
atol
=
0
)
...
...
test/test_prototype_datasets_builtin.py
View file @
86f551d3
import
functools
import
io
import
io
import
pickle
import
pickle
from
collections
import
deque
from
collections
import
deque
...
@@ -9,7 +8,7 @@ import torch
...
@@ -9,7 +8,7 @@ import torch
import
torchvision.prototype.transforms.utils
import
torchvision.prototype.transforms.utils
from
builtin_dataset_mocks
import
DATASET_MOCKS
,
parametrize_dataset_mocks
from
builtin_dataset_mocks
import
DATASET_MOCKS
,
parametrize_dataset_mocks
from
torch.testing._comparison
import
assert_equal
,
ObjectPair
,
TensorLikePair
from
torch.testing._comparison
import
not_close_error_metas
,
ObjectPair
,
TensorLikePair
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
...
@@ -25,9 +24,12 @@ from torchvision.prototype import datapoints, datasets, transforms
...
@@ -25,9 +24,12 @@ from torchvision.prototype import datapoints, datasets, transforms
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
from
torchvision.prototype.datasets.utils._internal
import
INFINITE_BUFFER_SIZE
assert_samples_equal
=
functools
.
partial
(
def
assert_samples_equal
(
*
args
,
msg
=
None
,
**
kwargs
):
assert_equal
,
pair_types
=
(
TensorLikePair
,
ObjectPair
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
error_metas
=
not_close_error_metas
(
)
*
args
,
pair_types
=
(
TensorLikePair
,
ObjectPair
),
rtol
=
0
,
atol
=
0
,
equal_nan
=
True
,
**
kwargs
)
if
error_metas
:
raise
error_metas
[
0
].
to_error
(
msg
)
def
extract_datapipes
(
dp
):
def
extract_datapipes
(
dp
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment