Unverified Commit 272e080c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add initial chunk of prototype transforms (#4861)

* add initial chunk of prototype transforms

* fix tests

* add error message

* fix more imports

* add explicit no-ops

* add test for no-ops

* cleanup
parent 57e6e302
...@@ -2,7 +2,7 @@ import unittest.mock ...@@ -2,7 +2,7 @@ import unittest.mock
import pytest import pytest
from torchvision.prototype import datasets from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import FrozenMapping, FrozenBunch from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs): def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
......
...@@ -11,6 +11,11 @@ from torchvision.prototype.utils._internal import sequence_to_str ...@@ -11,6 +11,11 @@ from torchvision.prototype.utils._internal import sequence_to_str
make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32) make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)
def make_image(**kwargs):
    """Create a random 3-channel ``features.Image`` with side lengths in [16, 32]."""
    height, width = torch.randint(16, 33, (2,)).tolist()
    return features.Image(make_tensor((3, height, width)), **kwargs)
def make_bounding_box(*, format="xyxy", image_size=(10, 10)): def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
if isinstance(format, str): if isinstance(format, str):
format = features.BoundingBoxFormat[format] format = features.BoundingBoxFormat[format]
...@@ -42,6 +47,7 @@ def make_bounding_box(*, format="xyxy", image_size=(10, 10)): ...@@ -42,6 +47,7 @@ def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
MAKE_DATA_MAP = { MAKE_DATA_MAP = {
features.Image: make_image,
features.BoundingBox: make_bounding_box, features.BoundingBox: make_bounding_box,
} }
......
import pytest
from torchvision.prototype import transforms, features
from torchvision.prototype.utils._internal import sequence_to_str
# All concrete feature classes publicly exposed by `torchvision.prototype.features`
# (the abstract `Feature` base itself is excluded).
FEATURE_TYPES = {
    obj
    for public_name, obj in vars(features).items()
    if not public_name.startswith("_")
    and isinstance(obj, type)
    and issubclass(obj, features.Feature)
    and obj is not features.Feature
}
# All concrete transform classes publicly exposed by `torchvision.prototype.transforms`
# (the abstract `Transform` base itself is excluded).
TRANSFORM_TYPES = tuple(
    obj
    for public_name, obj in vars(transforms).items()
    if not public_name.startswith("_")
    and isinstance(obj, type)
    and issubclass(obj, transforms.Transform)
    and obj is not transforms.Transform
)
def test_feature_type_support():
    """Every publicly exposed feature type must be registered in ``Transform._BUILTIN_FEATURE_TYPES``."""
    missing = FEATURE_TYPES.difference(transforms.Transform._BUILTIN_FEATURE_TYPES)
    if not missing:
        return
    names = sorted(feature_type.__name__ for feature_type in missing)
    raise AssertionError(
        f"The feature(s) {sequence_to_str(names, separate_last='and ')} is/are exposed at "
        f"`torchvision.prototype.features`, but are missing in Transform._BUILTIN_FEATURE_TYPES. "
        f"Please add it/them to the collection."
    )
@pytest.mark.parametrize(
    "transform_type",
    [transform_type for transform_type in TRANSFORM_TYPES if transform_type is not transforms.Identity],
    ids=lambda transform_type: transform_type.__name__,
)
def test_no_op(transform_type):
    """Every feature type must either be supported by the transform or be an explicit no-op.

    ``Identity`` is excluded above because it is a no-op for everything by design.
    """
    unsupported_features = (
        FEATURE_TYPES - transform_type.supported_feature_types() - set(transform_type.NO_OP_FEATURE_TYPES)
    )
    if unsupported_features:
        names = sorted(feature_type.__name__ for feature_type in unsupported_features)
        # Fixed the duplicated word "the the" in the error message.
        raise AssertionError(
            f"The feature(s) {sequence_to_str(names, separate_last='and ')} are neither supported nor declared as "
            f"no-op for transform `{transform_type.__name__}`. Please either implement a feature transform for them, "
            f"or add them to the `{transform_type.__name__}.NO_OP_FEATURE_TYPES` collection."
        )
...@@ -20,8 +20,8 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -20,8 +20,8 @@ from torchvision.prototype.datasets.utils._internal import (
Enumerator, Enumerator,
getitem, getitem,
read_mat, read_mat,
FrozenMapping,
) )
from torchvision.prototype.utils._internal import FrozenMapping
class ImageNet(Dataset): class ImageNet(Dataset):
......
from . import _internal from . import _internal
from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
from ._query import SampleQuery
from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource
...@@ -9,10 +9,11 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple ...@@ -9,10 +9,11 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
import torch import torch
from torch.utils.data import IterDataPipe from torch.utils.data import IterDataPipe
from torchvision.prototype.utils._internal import FrozenBunch, make_repr
from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str
from .._home import use_sharded_dataset from .._home import use_sharded_dataset
from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe from ._internal import BUILTIN_DIR, _make_sharded_datapipe
from ._resource import OnlineResource from ._resource import OnlineResource
......
import csv
import enum import enum
import gzip import gzip
import io import io
...@@ -7,7 +6,6 @@ import os ...@@ -7,7 +6,6 @@ import os
import os.path import os.path
import pathlib import pathlib
import pickle import pickle
import textwrap
from typing import ( from typing import (
Sequence, Sequence,
Callable, Callable,
...@@ -18,10 +16,7 @@ from typing import ( ...@@ -18,10 +16,7 @@ from typing import (
Iterator, Iterator,
Dict, Dict,
Optional, Optional,
NoReturn,
IO, IO,
Iterable,
Mapping,
Sized, Sized,
) )
from typing import cast from typing import cast
...@@ -38,10 +33,6 @@ from torchdata.datapipes.utils import StreamWrapper ...@@ -38,10 +33,6 @@ from torchdata.datapipes.utils import StreamWrapper
__all__ = [ __all__ = [
"INFINITE_BUFFER_SIZE", "INFINITE_BUFFER_SIZE",
"BUILTIN_DIR", "BUILTIN_DIR",
"make_repr",
"FrozenMapping",
"FrozenBunch",
"create_categories_file",
"read_mat", "read_mat",
"image_buffer_from_array", "image_buffer_from_array",
"SequenceIterator", "SequenceIterator",
...@@ -62,82 +53,6 @@ INFINITE_BUFFER_SIZE = 1_000_000_000 ...@@ -62,82 +53,6 @@ INFINITE_BUFFER_SIZE = 1_000_000_000
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin" BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
    """Return a ``repr``-style string ``name(key=value, ...)``.

    Falls back to one ``key=value`` pair per line (indented by two spaces) when
    the single-line form exceeds the terminal width (``COLUMNS`` environment
    variable, default 80) or when any value's string form spans multiple lines.
    """
    # Materialize once: ``items`` is only required to be an Iterable, and the
    # previous implementation iterated it twice, silently producing an empty
    # body for one-shot iterators on the multi-line path.
    pairs = [f"{key}={value}" for key, value in items]

    prefix = f"{name}("
    postfix = ")"
    body = ", ".join(pairs)

    line_length = int(os.environ.get("COLUMNS", 80))
    body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
    multiline_body = len(body.splitlines()) > 1
    if not (body_too_long or multiline_body):
        return prefix + body + postfix

    body = textwrap.indent(",\n".join(pairs), " " * 2)
    return f"{prefix}\n{body}\n{postfix}"
class FrozenMapping(Mapping[K, D]):
    """A read-only ``Mapping`` that is hashable and rejects every mutation attempt."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Write through __dict__ directly so subclasses (e.g. FrozenBunch)
        # can block __setattr__ without breaking construction.
        contents = dict(*args, **kwargs)
        self.__dict__["__data__"] = contents
        # Hash is fixed once at construction time from the (key, value) pairs.
        self.__dict__["__final_hash__"] = hash(tuple(contents.items()))

    def __getitem__(self, item: K) -> D:
        return cast(Mapping[K, D], self.__dict__["__data__"])[item]

    def __iter__(self) -> Iterator[K]:
        return iter(self.__dict__["__data__"])

    def __len__(self) -> int:
        return len(self.__dict__["__data__"])

    def __setitem__(self, key: K, value: Any) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __delitem__(self, key: K) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __hash__(self) -> int:
        return cast(int, self.__dict__["__final_hash__"])

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, FrozenMapping):
            return NotImplemented
        # Equality piggybacks on the precomputed hashes.
        return hash(self) == hash(other)

    def __repr__(self) -> str:
        return repr(self.__dict__["__data__"])
class FrozenBunch(FrozenMapping):
    """A ``FrozenMapping`` whose entries are also reachable as attributes."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError as missing_key:
            # Surface missing keys as the conventional AttributeError, keeping
            # the KeyError as the explicit cause.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from missing_key

    def __setattr__(self, key: Any, value: Any) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __delattr__(self, item: Any) -> NoReturn:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __repr__(self) -> str:
        # Multi-line aware repr, e.g. FrozenBunch(key=value, ...).
        return make_repr(type(self).__name__, self.items())
def create_categories_file(
    root: Union[str, pathlib.Path], name: str, categories: Sequence[Union[str, Sequence[str]]], **fmtparams: Any
) -> None:
    """Write ``categories`` as CSV rows to ``<root>/<name>.categories``.

    ``fmtparams`` are forwarded verbatim to :func:`csv.writer`.
    """
    target = pathlib.Path(root) / f"{name}.categories"
    # newline="" lets the csv module control line endings itself.
    with open(target, "w", newline="") as file:
        writer = csv.writer(file, **fmtparams)
        writer.writerows(categories)
def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
try: try:
import scipy.io as sio import scipy.io as sio
......
import collections.abc
from typing import Any, Callable, Iterator, Optional, Tuple, TypeVar, cast
from torchvision.prototype.features import BoundingBox, Image
T = TypeVar("T")


class SampleQuery:
    """Search a (possibly nested) sample structure for exactly one value.

    ``fn`` is applied to every leaf of the sample; sequences and mappings are
    traversed recursively. ``query`` raises if the search finds no result or
    more than one distinct result.
    """

    def __init__(self, sample: Any) -> None:
        self.sample = sample

    @staticmethod
    def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]:
        # str/bytes are Sequences of length-1 str/bytes, so recursing into
        # them never terminates (RecursionError). Treat them as leaves.
        if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)) and not isinstance(
            sample, (str, bytes)
        ):
            for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample:
                yield from SampleQuery._query_recursively(item, fn)
        else:
            result = fn(sample)
            if result is not None:
                yield result

    def query(self, fn: Callable[[Any], Optional[T]]) -> T:
        """Return the single value ``fn`` extracted from the sample.

        Raises:
            RuntimeError: if no leaf or more than one distinct leaf matched.
        """
        results = set(self._query_recursively(self.sample, fn))
        if not results:
            raise RuntimeError("Query turned up empty.")
        elif len(results) > 1:
            raise RuntimeError(f"Found more than one result: {results}")
        return results.pop()

    def image_size(self) -> Tuple[int, int]:
        """Return the (height, width) implied by any Image or BoundingBox in the sample."""

        def fn(sample: Any) -> Optional[Tuple[int, int]]:
            if isinstance(sample, Image):
                return cast(Tuple[int, int], sample.shape[-2:])
            elif isinstance(sample, BoundingBox):
                return sample.image_size
            else:
                return None

        return self.query(fn)
...@@ -18,6 +18,15 @@ class Image(Feature): ...@@ -18,6 +18,15 @@ class Image(Feature):
color_spaces = ColorSpace color_spaces = ColorSpace
color_space: ColorSpace color_space: ColorSpace
@classmethod
def _to_tensor(cls, data, *, dtype, device):
    """Convert ``data`` into a 3D image tensor.

    2D input is treated as a single-channel image and gains a leading channel
    dimension; any dimensionality other than 2 or 3 is rejected.
    """
    tensor = torch.as_tensor(data, dtype=dtype, device=device)
    if tensor.ndim not in (2, 3):
        raise ValueError("Only single images with 2 or 3 dimensions are allowed.")
    return tensor.unsqueeze(0) if tensor.ndim == 2 else tensor
@classmethod @classmethod
def _parse_meta_data( def _parse_meta_data(
cls, cls,
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet from ...models.alexnet import AlexNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -4,10 +4,10 @@ from functools import partial ...@@ -4,10 +4,10 @@ from functools import partial
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
import torch.nn as nn import torch.nn as nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet from ...models.densenet import DenseNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
import warnings import warnings
from typing import Any, Optional, Union from typing import Any, Optional, Union
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.faster_rcnn import ( from ....models.detection.faster_rcnn import (
...@@ -12,7 +13,6 @@ from ....models.detection.faster_rcnn import ( ...@@ -12,7 +13,6 @@ from ....models.detection.faster_rcnn import (
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
) )
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
......
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval
from ....models.detection.keypoint_rcnn import ( from ....models.detection.keypoint_rcnn import (
_resnet_fpn_extractor, _resnet_fpn_extractor,
_validate_trainable_layers, _validate_trainable_layers,
...@@ -8,7 +10,6 @@ from ....models.detection.keypoint_rcnn import ( ...@@ -8,7 +10,6 @@ from ....models.detection.keypoint_rcnn import (
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
) )
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50Weights, resnet50
......
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.mask_rcnn import ( from ....models.detection.mask_rcnn import (
...@@ -10,7 +11,6 @@ from ....models.detection.mask_rcnn import ( ...@@ -10,7 +11,6 @@ from ....models.detection.mask_rcnn import (
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
) )
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50Weights, resnet50
......
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.retinanet import ( from ....models.detection.retinanet import (
...@@ -11,7 +12,6 @@ from ....models.detection.retinanet import ( ...@@ -11,7 +12,6 @@ from ....models.detection.retinanet import (
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
) )
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50Weights, resnet50
......
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.ssd import ( from ....models.detection.ssd import (
...@@ -9,7 +10,6 @@ from ....models.detection.ssd import ( ...@@ -9,7 +10,6 @@ from ....models.detection.ssd import (
DefaultBoxGenerator, DefaultBoxGenerator,
SSD, SSD,
) )
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from ..vgg import VGG16Weights, vgg16 from ..vgg import VGG16Weights, vgg16
......
...@@ -3,6 +3,7 @@ from functools import partial ...@@ -3,6 +3,7 @@ from functools import partial
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from torch import nn from torch import nn
from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.ssdlite import ( from ....models.detection.ssdlite import (
...@@ -14,7 +15,6 @@ from ....models.detection.ssdlite import ( ...@@ -14,7 +15,6 @@ from ....models.detection.ssdlite import (
SSD, SSD,
SSDLiteHead, SSDLiteHead,
) )
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
......
...@@ -3,10 +3,10 @@ from functools import partial ...@@ -3,10 +3,10 @@ from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torch import nn from torch import nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig from ...models.efficientnet import EfficientNet, MBConvConfig
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment