Unverified Commit 86f551d3 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

update usages of torch.testing internals (#7203)

parent 5ea8e013
...@@ -14,7 +14,7 @@ import torch ...@@ -14,7 +14,7 @@ import torch
import torch.testing import torch.testing
from datasets_utils import combinations_grid from datasets_utils import combinations_grid
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value from torchvision.transforms.functional_tensor import _max_value as get_max_value
...@@ -73,7 +73,7 @@ class ImagePair(TensorLikePair): ...@@ -73,7 +73,7 @@ class ImagePair(TensorLikePair):
actual, expected = self._promote_for_comparison(actual, expected) actual, expected = self._promote_for_comparison(actual, expected)
mae = float(torch.abs(actual - expected).float().mean()) mae = float(torch.abs(actual - expected).float().mean())
if mae > self.atol: if mae > self.atol:
raise self._make_error_meta( self._fail(
AssertionError, AssertionError,
f"The MAE of the images is {mae}, but only {self.atol} is allowed.", f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
) )
...@@ -99,7 +99,7 @@ def assert_close( ...@@ -99,7 +99,7 @@ def assert_close(
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison""" """Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
__tracebackhide__ = True __tracebackhide__ = True
_assert_equal( error_metas = not_close_error_metas(
actual, actual,
expected, expected,
pair_types=( pair_types=(
...@@ -117,10 +117,12 @@ def assert_close( ...@@ -117,10 +117,12 @@ def assert_close(
check_dtype=check_dtype, check_dtype=check_dtype,
check_layout=check_layout, check_layout=check_layout,
check_stride=check_stride, check_stride=check_stride,
msg=msg,
**kwargs, **kwargs,
) )
if error_metas:
raise error_metas[0].to_error(msg)
assert_equal = functools.partial(assert_close, rtol=0, atol=0) assert_equal = functools.partial(assert_close, rtol=0, atol=0)
......
import functools
import io import io
import pickle import pickle
from collections import deque from collections import deque
...@@ -9,7 +8,7 @@ import torch ...@@ -9,7 +8,7 @@ import torch
import torchvision.prototype.transforms.utils import torchvision.prototype.transforms.utils
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish # TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -25,9 +24,12 @@ from torchvision.prototype import datapoints, datasets, transforms ...@@ -25,9 +24,12 @@ from torchvision.prototype import datapoints, datasets, transforms
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
assert_samples_equal = functools.partial( def assert_samples_equal(*args, msg=None, **kwargs):
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True error_metas = not_close_error_metas(
) *args, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True, **kwargs
)
if error_metas:
raise error_metas[0].to_error(msg)
def extract_datapipes(dp): def extract_datapipes(dp):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment