Unverified Commit d5f4cc38 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
......@@ -2,9 +2,9 @@ import pathlib
from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
......
......@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
......
......@@ -2,7 +2,6 @@ from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -12,6 +11,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
......
......@@ -2,8 +2,6 @@ import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper
from torchvision.datapoints import BoundingBoxes
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......@@ -11,6 +9,8 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
path_comparator,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
......
......@@ -15,7 +15,6 @@ from torchdata.datapipes.iter import (
TarArchiveLoader,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, ManualDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -26,6 +25,7 @@ from torchvision.prototype.datasets.utils._internal import (
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
......
......@@ -7,11 +7,11 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence
import torch
from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE
from torchvision.prototype.tv_tensors import Label
from torchvision.prototype.utils._internal import fromfile
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
......
......@@ -3,7 +3,6 @@ import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -14,6 +13,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
......
......@@ -4,10 +4,10 @@ from collections import namedtuple
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
......
......@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Tuple, Union
import torch
from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import OneHotLabel
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import OneHotLabel
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
......
......@@ -2,8 +2,6 @@ import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
from torchvision.datapoints import BoundingBoxes
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......@@ -12,6 +10,8 @@ from torchvision.prototype.datasets.utils._internal import (
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
......
......@@ -3,10 +3,10 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np
from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
......
......@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
......
......@@ -5,9 +5,7 @@ from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
from xml.etree import ElementTree
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.datapoints import BoundingBoxes
from torchvision.datasets import VOCDetection
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -18,6 +16,8 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
......
......@@ -5,9 +5,9 @@ import pathlib
from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedData, EncodedImage
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
__all__ = ["from_data_folder", "from_image_folder"]
......
......@@ -6,14 +6,14 @@ from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torchvision.datapoints._datapoint import Datapoint
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
from torchvision.tv_tensors._tv_tensor import TVTensor
D = TypeVar("D", bound="EncodedData")
class EncodedData(Datapoint):
class EncodedData(TVTensor):
@classmethod
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)
......
......@@ -3,9 +3,9 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints as proto_datapoints
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2._utils import is_pure_tensor
......@@ -26,9 +26,9 @@ class SimpleCopyPaste(Transform):
def _copy_paste(
self,
image: Union[torch.Tensor, datapoints.Image],
image: Union[torch.Tensor, tv_tensors.Image],
target: Dict[str, Any],
paste_image: Union[torch.Tensor, datapoints.Image],
paste_image: Union[torch.Tensor, tv_tensors.Image],
paste_target: Dict[str, Any],
random_selection: torch.Tensor,
blending: bool,
......@@ -36,9 +36,9 @@ class SimpleCopyPaste(Transform):
antialias: Optional[bool],
) -> Tuple[torch.Tensor, Dict[str, Any]]:
paste_masks = datapoints.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
paste_boxes = datapoints.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
paste_labels = datapoints.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])
paste_masks = tv_tensors.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
paste_boxes = tv_tensors.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
paste_labels = tv_tensors.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])
masks = target["masks"]
......@@ -81,7 +81,7 @@ class SimpleCopyPaste(Transform):
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1
boxes = F.convert_bounding_box_format(
xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
xyxy_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
......@@ -90,7 +90,7 @@ class SimpleCopyPaste(Transform):
# Check for degenerated boxes and remove them
boxes = F.convert_bounding_box_format(
out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
out_target["boxes"], old_format=bbox_format, new_format=tv_tensors.BoundingBoxFormat.XYXY
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
......@@ -104,20 +104,20 @@ class SimpleCopyPaste(Transform):
def _extract_image_targets(
self, flat_sample: List[Any]
) -> Tuple[List[Union[torch.Tensor, datapoints.Image]], List[Dict[str, Any]]]:
) -> Tuple[List[Union[torch.Tensor, tv_tensors.Image]], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBoxes], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, datapoints.Image) or is_pure_tensor(obj):
if isinstance(obj, tv_tensors.Image) or is_pure_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image(obj))
elif isinstance(obj, datapoints.BoundingBoxes):
elif isinstance(obj, tv_tensors.BoundingBoxes):
bboxes.append(obj)
elif isinstance(obj, datapoints.Mask):
elif isinstance(obj, tv_tensors.Mask):
masks.append(obj)
elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
labels.append(obj)
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
......@@ -140,8 +140,8 @@ class SimpleCopyPaste(Transform):
) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, datapoints.Image):
flat_sample[i] = datapoints.wrap(output_images[c0], like=obj)
if isinstance(obj, tv_tensors.Image):
flat_sample[i] = tv_tensors.wrap(output_images[c0], like=obj)
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_pil_image(output_images[c0])
......@@ -149,14 +149,14 @@ class SimpleCopyPaste(Transform):
elif is_pure_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, datapoints.BoundingBoxes):
flat_sample[i] = datapoints.wrap(output_targets[c1]["boxes"], like=obj)
elif isinstance(obj, tv_tensors.BoundingBoxes):
flat_sample[i] = tv_tensors.wrap(output_targets[c1]["boxes"], like=obj)
c1 += 1
elif isinstance(obj, datapoints.Mask):
flat_sample[i] = datapoints.wrap(output_targets[c2]["masks"], like=obj)
elif isinstance(obj, tv_tensors.Mask):
flat_sample[i] = tv_tensors.wrap(output_targets[c2]["masks"], like=obj)
c2 += 1
elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
flat_sample[i] = datapoints.wrap(output_targets[c3]["labels"], like=obj)
elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
flat_sample[i] = tv_tensors.wrap(output_targets[c3]["labels"], like=obj)
c3 += 1
def forward(self, *inputs: Any) -> Any:
......
......@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Sequence, Type, Union
import PIL.Image
import torch
from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision import tv_tensors
from torchvision.prototype.tv_tensors import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import (
_FillType,
......@@ -39,15 +39,15 @@ class FixedSizeCrop(Transform):
if not has_any(
flat_inputs,
PIL.Image.Image,
datapoints.Image,
tv_tensors.Image,
is_pure_tensor,
datapoints.Video,
tv_tensors.Video,
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
)
if has_any(flat_inputs, datapoints.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel):
if has_any(flat_inputs, tv_tensors.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel):
raise TypeError(
f"If a BoundingBoxes is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
......@@ -85,7 +85,7 @@ class FixedSizeCrop(Transform):
)
bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
height_and_width = F.convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYWH
)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
......@@ -119,10 +119,10 @@ class FixedSizeCrop(Transform):
)
if params["is_valid"] is not None:
if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)):
inpt = datapoints.wrap(inpt[params["is_valid"]], like=inpt)
elif isinstance(inpt, datapoints.BoundingBoxes):
inpt = datapoints.wrap(
if isinstance(inpt, (Label, OneHotLabel, tv_tensors.Mask)):
inpt = tv_tensors.wrap(inpt[params["is_valid"]], like=inpt)
elif isinstance(inpt, tv_tensors.BoundingBoxes):
inpt = tv_tensors.wrap(
F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
like=inpt,
)
......
......@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2._utils import is_pure_tensor
......@@ -25,17 +25,17 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
class PermuteDimensions(Transform):
_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
warnings.warn(
"Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
)
self.dims = dims
......@@ -47,17 +47,17 @@ class PermuteDimensions(Transform):
class TransposeDimensions(Transform):
_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
_transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
warnings.warn(
"Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
"in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
)
self.dims = dims
......
......@@ -4,23 +4,23 @@ import torch
from torch.nn.functional import one_hot
from torchvision.prototype import datapoints as proto_datapoints
from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import Transform
class LabelToOneHot(Transform):
_transformed_types = (proto_datapoints.Label,)
_transformed_types = (proto_tv_tensors.Label,)
def __init__(self, num_categories: int = -1):
super().__init__()
self.num_categories = num_categories
def _transform(self, inpt: proto_datapoints.Label, params: Dict[str, Any]) -> proto_datapoints.OneHotLabel:
def _transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel:
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return proto_datapoints.OneHotLabel(output, categories=inpt.categories)
return proto_tv_tensors.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str:
if self.num_categories == -1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment