Unverified Commit d5f4cc38 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
...@@ -2,9 +2,9 @@ import pathlib ...@@ -2,9 +2,9 @@ import pathlib
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Union ...@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Union
import torch import torch
from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -2,7 +2,6 @@ from pathlib import Path ...@@ -2,7 +2,6 @@ from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
...@@ -12,6 +11,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -12,6 +11,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator, path_comparator,
read_categories_file, read_categories_file,
) )
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -2,8 +2,6 @@ import pathlib ...@@ -2,8 +2,6 @@ import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper
from torchvision.datapoints import BoundingBoxes
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
hint_sharding, hint_sharding,
...@@ -11,6 +9,8 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -11,6 +9,8 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
path_comparator, path_comparator,
) )
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -15,7 +15,6 @@ from torchdata.datapipes.iter import ( ...@@ -15,7 +15,6 @@ from torchdata.datapipes.iter import (
TarArchiveLoader, TarArchiveLoader,
) )
from torchdata.datapipes.map import IterToMapConverter from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, ManualDownloadResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, ManualDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
...@@ -26,6 +25,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -26,6 +25,7 @@ from torchvision.prototype.datasets.utils._internal import (
read_categories_file, read_categories_file,
read_mat, read_mat,
) )
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -7,11 +7,11 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence ...@@ -7,11 +7,11 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence
import torch import torch
from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE
from torchvision.prototype.tv_tensors import Label
from torchvision.prototype.utils._internal import fromfile from torchvision.prototype.utils._internal import fromfile
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -3,7 +3,6 @@ import pathlib ...@@ -3,7 +3,6 @@ import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
...@@ -14,6 +13,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -14,6 +13,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator, path_comparator,
read_categories_file, read_categories_file,
) )
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -4,10 +4,10 @@ from collections import namedtuple ...@@ -4,10 +4,10 @@ from collections import namedtuple
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Tuple, Union ...@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Tuple, Union
import torch import torch
from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import OneHotLabel
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import OneHotLabel
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -2,8 +2,6 @@ import pathlib ...@@ -2,8 +2,6 @@ import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
from torchvision.datapoints import BoundingBoxes
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
hint_sharding, hint_sharding,
...@@ -12,6 +10,8 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -12,6 +10,8 @@ from torchvision.prototype.datasets.utils._internal import (
read_categories_file, read_categories_file,
read_mat, read_mat,
) )
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -3,10 +3,10 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union ...@@ -3,10 +3,10 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np import numpy as np
from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Union ...@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Union
import torch import torch
from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -5,9 +5,7 @@ from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union ...@@ -5,9 +5,7 @@ from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
from xml.etree import ElementTree from xml.etree import ElementTree
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.datapoints import BoundingBoxes
from torchvision.datasets import VOCDetection from torchvision.datasets import VOCDetection
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
...@@ -18,6 +16,8 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -18,6 +16,8 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator, path_comparator,
read_categories_file, read_categories_file,
) )
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
......
...@@ -5,9 +5,9 @@ import pathlib ...@@ -5,9 +5,9 @@ import pathlib
from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedData, EncodedImage from torchvision.prototype.datasets.utils import EncodedData, EncodedImage
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
__all__ = ["from_data_folder", "from_image_folder"] __all__ = ["from_data_folder", "from_image_folder"]
......
...@@ -6,14 +6,14 @@ from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union ...@@ -6,14 +6,14 @@ from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
import PIL.Image import PIL.Image
import torch import torch
from torchvision.datapoints._datapoint import Datapoint
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
from torchvision.tv_tensors._tv_tensor import TVTensor
D = TypeVar("D", bound="EncodedData") D = TypeVar("D", bound="EncodedData")
class EncodedData(Datapoint): class EncodedData(TVTensor):
@classmethod @classmethod
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls) return tensor.as_subclass(cls)
......
...@@ -3,9 +3,9 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union ...@@ -3,9 +3,9 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.ops import masks_to_boxes from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints as proto_datapoints from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2._utils import is_pure_tensor from torchvision.transforms.v2._utils import is_pure_tensor
...@@ -26,9 +26,9 @@ class SimpleCopyPaste(Transform): ...@@ -26,9 +26,9 @@ class SimpleCopyPaste(Transform):
def _copy_paste( def _copy_paste(
self, self,
image: Union[torch.Tensor, datapoints.Image], image: Union[torch.Tensor, tv_tensors.Image],
target: Dict[str, Any], target: Dict[str, Any],
paste_image: Union[torch.Tensor, datapoints.Image], paste_image: Union[torch.Tensor, tv_tensors.Image],
paste_target: Dict[str, Any], paste_target: Dict[str, Any],
random_selection: torch.Tensor, random_selection: torch.Tensor,
blending: bool, blending: bool,
...@@ -36,9 +36,9 @@ class SimpleCopyPaste(Transform): ...@@ -36,9 +36,9 @@ class SimpleCopyPaste(Transform):
antialias: Optional[bool], antialias: Optional[bool],
) -> Tuple[torch.Tensor, Dict[str, Any]]: ) -> Tuple[torch.Tensor, Dict[str, Any]]:
paste_masks = datapoints.wrap(paste_target["masks"][random_selection], like=paste_target["masks"]) paste_masks = tv_tensors.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
paste_boxes = datapoints.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"]) paste_boxes = tv_tensors.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
paste_labels = datapoints.wrap(paste_target["labels"][random_selection], like=paste_target["labels"]) paste_labels = tv_tensors.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])
masks = target["masks"] masks = target["masks"]
...@@ -81,7 +81,7 @@ class SimpleCopyPaste(Transform): ...@@ -81,7 +81,7 @@ class SimpleCopyPaste(Transform):
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422 # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1 xyxy_boxes[:, 2:] += 1
boxes = F.convert_bounding_box_format( boxes = F.convert_bounding_box_format(
xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True xyxy_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
) )
out_target["boxes"] = torch.cat([boxes, paste_boxes]) out_target["boxes"] = torch.cat([boxes, paste_boxes])
...@@ -90,7 +90,7 @@ class SimpleCopyPaste(Transform): ...@@ -90,7 +90,7 @@ class SimpleCopyPaste(Transform):
# Check for degenerated boxes and remove them # Check for degenerated boxes and remove them
boxes = F.convert_bounding_box_format( boxes = F.convert_bounding_box_format(
out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY out_target["boxes"], old_format=bbox_format, new_format=tv_tensors.BoundingBoxFormat.XYXY
) )
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any(): if degenerate_boxes.any():
...@@ -104,20 +104,20 @@ class SimpleCopyPaste(Transform): ...@@ -104,20 +104,20 @@ class SimpleCopyPaste(Transform):
def _extract_image_targets( def _extract_image_targets(
self, flat_sample: List[Any] self, flat_sample: List[Any]
) -> Tuple[List[Union[torch.Tensor, datapoints.Image]], List[Dict[str, Any]]]: ) -> Tuple[List[Union[torch.Tensor, tv_tensors.Image]], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input # fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBoxes], List[Mask], List[Label] # with List[image], List[BoundingBoxes], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], [] images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample: for obj in flat_sample:
if isinstance(obj, datapoints.Image) or is_pure_tensor(obj): if isinstance(obj, tv_tensors.Image) or is_pure_tensor(obj):
images.append(obj) images.append(obj)
elif isinstance(obj, PIL.Image.Image): elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image(obj)) images.append(F.to_image(obj))
elif isinstance(obj, datapoints.BoundingBoxes): elif isinstance(obj, tv_tensors.BoundingBoxes):
bboxes.append(obj) bboxes.append(obj)
elif isinstance(obj, datapoints.Mask): elif isinstance(obj, tv_tensors.Mask):
masks.append(obj) masks.append(obj)
elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)): elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
labels.append(obj) labels.append(obj)
if not (len(images) == len(bboxes) == len(masks) == len(labels)): if not (len(images) == len(bboxes) == len(masks) == len(labels)):
...@@ -140,8 +140,8 @@ class SimpleCopyPaste(Transform): ...@@ -140,8 +140,8 @@ class SimpleCopyPaste(Transform):
) -> None: ) -> None:
c0, c1, c2, c3 = 0, 0, 0, 0 c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample): for i, obj in enumerate(flat_sample):
if isinstance(obj, datapoints.Image): if isinstance(obj, tv_tensors.Image):
flat_sample[i] = datapoints.wrap(output_images[c0], like=obj) flat_sample[i] = tv_tensors.wrap(output_images[c0], like=obj)
c0 += 1 c0 += 1
elif isinstance(obj, PIL.Image.Image): elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_pil_image(output_images[c0]) flat_sample[i] = F.to_pil_image(output_images[c0])
...@@ -149,14 +149,14 @@ class SimpleCopyPaste(Transform): ...@@ -149,14 +149,14 @@ class SimpleCopyPaste(Transform):
elif is_pure_tensor(obj): elif is_pure_tensor(obj):
flat_sample[i] = output_images[c0] flat_sample[i] = output_images[c0]
c0 += 1 c0 += 1
elif isinstance(obj, datapoints.BoundingBoxes): elif isinstance(obj, tv_tensors.BoundingBoxes):
flat_sample[i] = datapoints.wrap(output_targets[c1]["boxes"], like=obj) flat_sample[i] = tv_tensors.wrap(output_targets[c1]["boxes"], like=obj)
c1 += 1 c1 += 1
elif isinstance(obj, datapoints.Mask): elif isinstance(obj, tv_tensors.Mask):
flat_sample[i] = datapoints.wrap(output_targets[c2]["masks"], like=obj) flat_sample[i] = tv_tensors.wrap(output_targets[c2]["masks"], like=obj)
c2 += 1 c2 += 1
elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)): elif isinstance(obj, (proto_tv_tensors.Label, proto_tv_tensors.OneHotLabel)):
flat_sample[i] = datapoints.wrap(output_targets[c3]["labels"], like=obj) flat_sample[i] = tv_tensors.wrap(output_targets[c3]["labels"], like=obj)
c3 += 1 c3 += 1
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
......
...@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Sequence, Type, Union ...@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Sequence, Type, Union
import PIL.Image import PIL.Image
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.prototype.tv_tensors import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import ( from torchvision.transforms.v2._utils import (
_FillType, _FillType,
...@@ -39,15 +39,15 @@ class FixedSizeCrop(Transform): ...@@ -39,15 +39,15 @@ class FixedSizeCrop(Transform):
if not has_any( if not has_any(
flat_inputs, flat_inputs,
PIL.Image.Image, PIL.Image.Image,
datapoints.Image, tv_tensors.Image,
is_pure_tensor, is_pure_tensor,
datapoints.Video, tv_tensors.Video,
): ):
raise TypeError( raise TypeError(
f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video." f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
) )
if has_any(flat_inputs, datapoints.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel): if has_any(flat_inputs, tv_tensors.BoundingBoxes) and not has_any(flat_inputs, Label, OneHotLabel):
raise TypeError( raise TypeError(
f"If a BoundingBoxes is contained in the input sample, " f"If a BoundingBoxes is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel." f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
...@@ -85,7 +85,7 @@ class FixedSizeCrop(Transform): ...@@ -85,7 +85,7 @@ class FixedSizeCrop(Transform):
) )
bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size) bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
height_and_width = F.convert_bounding_box_format( height_and_width = F.convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYWH
)[..., 2:] )[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1) is_valid = torch.all(height_and_width > 0, dim=-1)
else: else:
...@@ -119,10 +119,10 @@ class FixedSizeCrop(Transform): ...@@ -119,10 +119,10 @@ class FixedSizeCrop(Transform):
) )
if params["is_valid"] is not None: if params["is_valid"] is not None:
if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)): if isinstance(inpt, (Label, OneHotLabel, tv_tensors.Mask)):
inpt = datapoints.wrap(inpt[params["is_valid"]], like=inpt) inpt = tv_tensors.wrap(inpt[params["is_valid"]], like=inpt)
elif isinstance(inpt, datapoints.BoundingBoxes): elif isinstance(inpt, tv_tensors.BoundingBoxes):
inpt = datapoints.wrap( inpt = tv_tensors.wrap(
F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size), F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
like=inpt, like=inpt,
) )
......
...@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union ...@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms.v2 import Transform from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2._utils import is_pure_tensor from torchvision.transforms.v2._utils import is_pure_tensor
...@@ -25,17 +25,17 @@ def _get_defaultdict(default: T) -> Dict[Any, T]: ...@@ -25,17 +25,17 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
class PermuteDimensions(Transform): class PermuteDimensions(Transform):
_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video) _transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None: def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
super().__init__() super().__init__()
if not isinstance(dims, dict): if not isinstance(dims, dict):
dims = _get_defaultdict(dims) dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
warnings.warn( warnings.warn(
"Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " "Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input." "in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
) )
self.dims = dims self.dims = dims
...@@ -47,17 +47,17 @@ class PermuteDimensions(Transform): ...@@ -47,17 +47,17 @@ class PermuteDimensions(Transform):
class TransposeDimensions(Transform): class TransposeDimensions(Transform):
_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video) _transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)
def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None: def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
super().__init__() super().__init__()
if not isinstance(dims, dict): if not isinstance(dims, dict):
dims = _get_defaultdict(dims) dims = _get_defaultdict(dims)
if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): if torch.Tensor in dims and any(cls in dims for cls in [tv_tensors.Image, tv_tensors.Video]):
warnings.warn( warnings.warn(
"Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " "Got `dims` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input." "in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
) )
self.dims = dims self.dims = dims
......
...@@ -4,23 +4,23 @@ import torch ...@@ -4,23 +4,23 @@ import torch
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torchvision.prototype import datapoints as proto_datapoints from torchvision.prototype import tv_tensors as proto_tv_tensors
from torchvision.transforms.v2 import Transform from torchvision.transforms.v2 import Transform
class LabelToOneHot(Transform): class LabelToOneHot(Transform):
_transformed_types = (proto_datapoints.Label,) _transformed_types = (proto_tv_tensors.Label,)
def __init__(self, num_categories: int = -1): def __init__(self, num_categories: int = -1):
super().__init__() super().__init__()
self.num_categories = num_categories self.num_categories = num_categories
def _transform(self, inpt: proto_datapoints.Label, params: Dict[str, Any]) -> proto_datapoints.OneHotLabel: def _transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel:
num_categories = self.num_categories num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None: if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories) num_categories = len(inpt.categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return proto_datapoints.OneHotLabel(output, categories=inpt.categories) return proto_tv_tensors.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str: def extra_repr(self) -> str:
if self.num_categories == -1: if self.num_categories == -1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment