"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4b7fe044e30249c9480498eb0ede4f15de58fe03"
Unverified Commit 3991ab99 authored by Philip Meier, committed by GitHub

Promote prototype transforms to beta status (#7261)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent d010e82f
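
Note: for users, this promotion is mostly an import rename. A minimal before/after sketch (assumes a build that ships the beta `torchvision.datapoints` and `torchvision.transforms.v2` namespaces; `horizontal_flip` is just an example dispatcher):

    # Before (prototype namespaces):
    #   from torchvision.prototype import datapoints
    #   from torchvision.prototype.transforms import functional as F
    # After this commit (beta namespaces):
    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    image = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
    flipped = F.horizontal_flip(image)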
@@ -584,11 +584,8 @@ class DatasetTestCase(unittest.TestCase):
     @test_all_configs
     def test_transforms_v2_wrapper(self, config):
-        # Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
-        # to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
-        # branch, we screwed up the roll-out
-        from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
-        from torchvision.prototype.datapoints._datapoint import Datapoint
+        from torchvision.datapoints import wrap_dataset_for_transforms_v2
+        from torchvision.datapoints._datapoint import Datapoint

         try:
             with self.create_dataset(config) as (dataset, _):
@@ -596,12 +593,13 @@ class DatasetTestCase(unittest.TestCase):
                 wrapped_sample = wrapped_dataset[0]
                 assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
         except TypeError as error:
-            if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
-                return
+            msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
+            if str(error).startswith(msg):
+                pytest.skip(msg)
             raise error
         except RuntimeError as error:
             if "currently not supported by this wrapper" in str(error):
-                return
+                pytest.skip("Config is currently not supported by this wrapper")
             raise error
...
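Note: for context, the wrapper exercised by the test above is used roughly like this (a sketch; the concrete dataset and arguments are illustrative, not taken from this diff):

    from torchvision import datasets
    from torchvision.datapoints import wrap_dataset_for_transforms_v2

    dataset = datasets.CIFAR10(root="data", download=True)
    wrapped = wrap_dataset_for_transforms_v2(dataset)
    image, label = wrapped[0]  # sample items are Datapoint subclasses or PIL images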
@@ -12,12 +12,13 @@ import PIL.Image
 import pytest
 import torch
 import torch.testing
+import torchvision.prototype.datapoints as proto_datapoints
 from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision import datapoints
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
+from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor

 __all__ = [
     "assert_close",
@@ -457,7 +458,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
         # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
         # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
         data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
-        return datapoints.Label(data, categories=categories)
+        return proto_datapoints.Label(data, categories=categories)

     return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)

@@ -481,7 +482,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int
         # since `one_hot` only supports int64
         label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
         data = one_hot(label, num_classes=num_categories).to(dtype)
-        return datapoints.OneHotLabel(data, categories=categories)
+        return proto_datapoints.OneHotLabel(data, categories=categories)

     return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
...
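Note: the idiom the comment above justifies, in isolation:

    import torch
    from torch.testing import make_tensor

    # Drawing int64 values first guarantees whole numbers; the cast afterwards
    # only changes the dtype, e.g. 0.0 or 3.0 rather than 0.1234.
    data = make_tensor((5,), low=0, high=4, dtype=torch.int64, device="cpu").to(torch.float32)
    assert bool((data == data.round()).all())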
 import collections.abc

 import pytest
-import torchvision.prototype.transforms.functional as F
+import torchvision.transforms.v2.functional as F
 from prototype_common_utils import InfoBase, TestMark
 from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
-from torchvision.prototype import datapoints
+from torchvision import datapoints

 __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
...
@@ -8,7 +8,7 @@ import PIL.Image
 import pytest
 import torch.testing
 import torchvision.ops
-import torchvision.prototype.transforms.functional as F
+import torchvision.transforms.v2.functional as F
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
@@ -28,7 +28,7 @@ from prototype_common_utils import (
     TestMark,
 )
 from torch.utils._pytree import tree_map
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding

 __all__ = ["KernelInfo", "KERNEL_INFOS"]
@@ -2383,19 +2383,18 @@ KERNEL_INFOS.extend(

 def sample_inputs_uniform_temporal_subsample_video():
     for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
-        for temporal_dim in [-4, len(video_loader.shape) - 4]:
-            yield ArgsKwargs(video_loader, num_samples=2, temporal_dim=temporal_dim)
+        yield ArgsKwargs(video_loader, num_samples=2)


-def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
+def reference_uniform_temporal_subsample_video(x, num_samples):
     # Copy-pasted from
     # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
-    t = x.shape[temporal_dim]
+    t = x.shape[-4]
     assert num_samples > 0 and t > 0
     # Sample by nearest neighbor interpolation if num_samples > t.
     indices = torch.linspace(0, t - 1, num_samples)
     indices = torch.clamp(indices, 0, t - 1).long()
-    return torch.index_select(x, temporal_dim, indices)
+    return torch.index_select(x, -4, indices)


 def reference_inputs_uniform_temporal_subsample_video():
@@ -2410,12 +2409,5 @@ KERNEL_INFOS.append(
         sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
         reference_fn=reference_uniform_temporal_subsample_video,
         reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
-        test_marks=[
-            TestMark(
-                ("TestKernels", "test_batched_vs_single"),
-                pytest.mark.skip("Positive `temporal_dim` arguments are not equivalent for batched and single inputs"),
-                condition=lambda args_kwargs: args_kwargs.kwargs.get("temporal_dim") >= 0,
-            ),
-        ],
     )
 )
...
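Note: with the `temporal_dim` parameter dropped, the kernel always subsamples along dimension -4. A self-contained restatement of the reference above:

    import torch

    def uniform_temporal_subsample(x, num_samples):
        # Nearest-neighbor selection of num_samples frame indices along dim -4.
        t = x.shape[-4]
        indices = torch.linspace(0, t - 1, num_samples).clamp(0, t - 1).long()
        return torch.index_select(x, -4, indices)

    video = torch.rand(8, 3, 16, 16)  # (T, C, H, W)
    assert uniform_temporal_subsample(video, num_samples=2).shape == (2, 3, 16, 16)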
@@ -5,8 +5,8 @@ import torch
 from PIL import Image

-from torchvision import datasets
-from torchvision.prototype import datapoints
+from torchvision import datapoints, datasets
+from torchvision.prototype import datapoints as proto_datapoints


 @pytest.mark.parametrize(
@@ -24,38 +24,38 @@ from torchvision.prototype import datapoints
     ],
 )
 def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
-    datapoint = datapoints.Label(data, requires_grad=input_requires_grad)
+    datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad)
     assert datapoint.requires_grad is expected_requires_grad


 def test_isinstance():
     assert isinstance(
-        datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
+        proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
         torch.Tensor,
     )


 def test_wrapping_no_copy():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

     assert label.data_ptr() == tensor.data_ptr()


 def test_to_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

     label_to = label.to(torch.int32)

-    assert type(label_to) is datapoints.Label
+    assert type(label_to) is proto_datapoints.Label
     assert label_to.dtype is torch.int32
     assert label_to.categories is label.categories


 def test_to_datapoint_reference():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)

     tensor_to = tensor.to(label)

@@ -65,31 +65,31 @@ def test_to_datapoint_reference():

 def test_clone_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

     label_clone = label.clone()

-    assert type(label_clone) is datapoints.Label
+    assert type(label_clone) is proto_datapoints.Label
     assert label_clone.data_ptr() != label.data_ptr()
     assert label_clone.categories is label.categories


 def test_requires_grad__wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

     assert not label.requires_grad

     label_requires_grad = label.requires_grad_(True)

-    assert type(label_requires_grad) is datapoints.Label
+    assert type(label_requires_grad) is proto_datapoints.Label
     assert label.requires_grad
     assert label_requires_grad.requires_grad


 def test_other_op_no_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

     # any operation besides .to() and .clone() will do here
     output = label * 2

@@ -107,33 +107,33 @@ def test_other_op_no_wrapping():
 )
 def test_no_tensor_output_op_no_wrapping(op):
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

     output = op(label)

-    assert type(output) is not datapoints.Label
+    assert type(output) is not proto_datapoints.Label


 def test_inplace_op_no_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

     output = label.add_(0)

     assert type(output) is torch.Tensor
-    assert type(label) is datapoints.Label
+    assert type(label) is proto_datapoints.Label


 def test_wrap_like():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

     # any operation besides .to() and .clone() will do here
     output = label * 2

-    label_new = datapoints.Label.wrap_like(label, output)
+    label_new = proto_datapoints.Label.wrap_like(label, output)

-    assert type(label_new) is datapoints.Label
+    assert type(label_new) is proto_datapoints.Label
     assert label_new.data_ptr() == output.data_ptr()
     assert label_new.categories is label.categories
...
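Note: the wrapping semantics these tests pin down, condensed into a sketch (`Label` stays in the prototype namespace after this change):

    import torch
    from torchvision.prototype import datapoints as proto_datapoints

    label = proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"])
    assert type(label.clone()) is proto_datapoints.Label  # .to()/.clone() keep the wrapper
    assert type(label * 2) is torch.Tensor                # other ops unwrap to plain tensors
    rewrapped = proto_datapoints.Label.wrap_like(label, label * 2)  # re-attach metadata
    assert rewrapped.categories is label.categories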
@@ -5,8 +5,8 @@ from pathlib import Path

 import pytest
 import torch
-import torchvision.prototype.transforms.utils
+import torchvision.transforms.v2 as transforms
 from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
 from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair

@@ -19,10 +19,13 @@ from torch.utils.data.graph_settings import get_all_graph_pipes
 from torchdata.dataloader2.graph.utils import traverse_dps
 from torchdata.datapipes.iter import ShardingFilter, Shuffler
 from torchdata.datapipes.utils import StreamWrapper
+from torchvision import datapoints
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datapoints, datasets, transforms
+from torchvision.prototype import datasets
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
+from torchvision.transforms.v2.utils import is_simple_tensor


 def assert_samples_equal(*args, msg=None, **kwargs):
@@ -141,9 +144,7 @@ class TestCommon:
         dataset, _ = dataset_mock.load(config)

         sample = next_consume(iter(dataset))
-        simple_tensors = {
-            key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
-        }
+        simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}

         if simple_tensors and not any(
             isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
@@ -276,6 +277,6 @@ class TestUSPS:
         assert "label" in sample

         assert isinstance(sample["image"], datapoints.Image)
-        assert isinstance(sample["label"], datapoints.Label)
+        assert isinstance(sample["label"], Label)

         assert sample["image"].shape == (1, 16, 16)
...
@@ -11,7 +11,6 @@ import pytest

 import torch
-import torchvision.prototype.transforms.utils
 from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
 from prototype_common_utils import (
     assert_close,
@@ -22,11 +21,12 @@ from prototype_common_utils import (
 from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
 from prototype_transforms_kernel_infos import KERNEL_INFOS
 from torch.utils._pytree import tree_map
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F
-from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
-from torchvision.prototype.transforms.functional._meta import clamp_bounding_box, convert_format_bounding_box
+from torchvision import datapoints
 from torchvision.transforms.functional import _get_perspective_coeffs
+from torchvision.transforms.v2 import functional as F
+from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
+from torchvision.transforms.v2.functional._meta import clamp_bounding_box, convert_format_bounding_box
+from torchvision.transforms.v2.utils import is_simple_tensor

 KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}

@@ -168,11 +168,7 @@ class TestKernels:
     def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         (batched_input, *other_args), kwargs = args_kwargs.load(device)

-        datapoint_type = (
-            datapoints.Image
-            if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
-            else type(batched_input)
-        )
+        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
         # This dictionary contains the number of rightmost dimensions that contain the actual data.
         # Everything to the left is considered a batch dimension.
         data_dims = {
...
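Note: a sketch of the batch-dimension convention the comment above describes (shapes are illustrative):

    import torch

    batched = torch.rand(2, 5, 3, 16, 16)  # two batch dims in front of (C, H, W)
    data_dims = 3                          # for images, the rightmost 3 dims are data
    single = batched.reshape(-1, *batched.shape[-data_dims:])[0]
    assert single.shape == (3, 16, 16)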
@@ -3,12 +3,12 @@ import pytest

 import torch
-import torchvision.prototype.transforms.utils
+import torchvision.transforms.v2.utils
 from prototype_common_utils import make_bounding_box, make_detection_mask, make_image
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms.functional import to_image_pil
-from torchvision.prototype.transforms.utils import has_all, has_any
+from torchvision import datapoints
+from torchvision.transforms.v2.functional import to_image_pil
+from torchvision.transforms.v2.utils import has_all, has_any


 IMAGE = make_image(color_space="RGB")

@@ -37,15 +37,15 @@ MASK = make_detection_mask(size=IMAGE.spatial_size)
         ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), True),
+        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True),
         (
             (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
             True,
         ),
         (
             (to_image_pil(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
             True,
         ),
     ],
...
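Note: a sketch of the relocated helpers (usage inferred from the tests above; treat the exact call shapes as an assumption):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2.utils import has_any, is_simple_tensor

    image = datapoints.Image(torch.rand(3, 8, 8))
    assert is_simple_tensor(torch.rand(3, 8, 8))  # plain tensors count as "simple"
    assert not is_simple_tensor(image)            # datapoint subclasses do not
    assert has_any((image,), datapoints.Image)    # type check over a flat sample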
+from ._bounding_box import BoundingBox, BoundingBoxFormat
+from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
+from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
+from ._mask import Mask
+from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
+
+from ._dataset_wrapper import wrap_dataset_for_transforms_v2  # type: ignore[attr-defined] # usort: skip
...
@@ -105,7 +105,7 @@ class Datapoint(torch.Tensor):
         # the class. This approach avoids the DataLoader issue described at
         # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
         if Datapoint.__F is None:
-            from ..transforms import functional
+            from ..transforms.v2 import functional

             Datapoint.__F = functional
         return Datapoint.__F
...
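Note: the hunk above retargets a lazily cached import. The pattern in isolation, as a sketch:

    import torch

    class Datapoint(torch.Tensor):
        __F = None  # module cache, resolved on first use

        @property
        def _F(self):
            # Importing on first access (instead of at module import time) breaks
            # the circular dependency between datapoints and transforms.v2 and
            # avoids the DataLoader issue referenced in the comment above.
            if Datapoint.__F is None:
                from torchvision.transforms.v2 import functional

                Datapoint.__F = functional
            return Datapoint.__F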
@@ -8,9 +8,8 @@ from collections import defaultdict

 import torch

 from torch.utils.data import Dataset
-from torchvision import datasets
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms import functional as F
+from torchvision import datapoints, datasets
+from torchvision.transforms.v2 import functional as F

 __all__ = ["wrap_dataset_for_transforms_v2"]
...
@@ -24,7 +24,7 @@ class Image(Datapoint):
         requires_grad: Optional[bool] = None,
     ) -> Image:
         if isinstance(data, PIL.Image.Image):
-            from torchvision.prototype.transforms import functional as F
+            from torchvision.transforms.v2 import functional as F

             data = F.pil_to_tensor(data)
...
@@ -23,7 +23,7 @@ class Mask(Datapoint):
         requires_grad: Optional[bool] = None,
     ) -> Mask:
         if isinstance(data, PIL.Image.Image):
-            from torchvision.prototype.transforms import functional as F
+            from torchvision.transforms.v2 import functional as F

             data = F.pil_to_tensor(data)
...
-from ._bounding_box import BoundingBox, BoundingBoxFormat
-from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
-from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
 from ._label import Label, OneHotLabel
-from ._mask import Mask
-from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
-
-from ._dataset_wrapper import wrap_dataset_for_transforms_v2  # type: ignore[attr-defined] # usort: skip
...
@@ -5,7 +5,7 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
 import torch
 from torch.utils._pytree import tree_map

-from ._datapoint import Datapoint
+from torchvision.datapoints._datapoint import Datapoint

 L = TypeVar("L", bound="_LabelBase")
...
@@ -6,7 +6,8 @@ import numpy as np

 import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
-from torchvision.prototype.datapoints import BoundingBox, Label
+from torchvision.datapoints import BoundingBox
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
...