"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7be02cbc0bca55b510641c2aef184eb24c98bdb1"
Unverified commit a8007dcd authored by Philip Meier, committed by GitHub

rename features._Feature to datapoints._Datapoint (#7002)

* rename features._Feature to datapoints.Datapoint

* _Datapoint to Datapoint

* move is_simple_tensor to transforms.utils

* fix CI

* move Datapoint out of public namespace
parent c093b9c0
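For callers of the prototype API, the rename mostly amounts to importing and constructing the wrappers from the new namespace. The following is a minimal sketch of the before/after, using only names that appear in the diff below; the prototype API is unstable, so these names may differ in other revisions:

import torch
import torchvision.prototype.transforms.utils
from torchvision.prototype import datapoints  # previously: from torchvision.prototype import features

# previously features.Image(...) and features.Label(...)
image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8))
label = datapoints.Label([0, 1, 0], categories=["foo", "bar"])

# is_simple_tensor moved out of the datapoints namespace into transforms.utils;
# it is True only for plain tensors, not for wrapped datapoints.
assert not torchvision.prototype.transforms.utils.is_simple_tensor(image)
assert torchvision.prototype.transforms.utils.is_simple_tensor(torch.rand(3, 32, 32))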
@@ -43,14 +43,14 @@ jobs:
        id: setup
        run: exit 0
-      - name: Run prototype features tests
+      - name: Run prototype datapoints tests
        shell: bash
        run: |
          pytest \
            --durations=20 \
-            --cov=torchvision/prototype/features \
+            --cov=torchvision/prototype/datapoints \
            --cov-report=term-missing \
-            test/test_prototype_features*.py
+            test/test_prototype_datapoints*.py
      - name: Run prototype transforms tests
        if: success() || ( failure() && steps.setup.conclusion == 'success' )
...
@@ -15,7 +15,7 @@ import torch.testing
from datasets_utils import combinations_grid
from torch.nn.functional import one_hot
from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value
@@ -238,7 +238,7 @@ class TensorLoader:
@dataclasses.dataclass
class ImageLoader(TensorLoader):
-    color_space: features.ColorSpace
+    color_space: datapoints.ColorSpace
    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
    num_channels: int = dataclasses.field(init=False)
@@ -248,10 +248,10 @@ class ImageLoader(TensorLoader):
NUM_CHANNELS_MAP = {
-    features.ColorSpace.GRAY: 1,
-    features.ColorSpace.GRAY_ALPHA: 2,
-    features.ColorSpace.RGB: 3,
-    features.ColorSpace.RGB_ALPHA: 4,
+    datapoints.ColorSpace.GRAY: 1,
+    datapoints.ColorSpace.GRAY_ALPHA: 2,
+    datapoints.ColorSpace.RGB: 3,
+    datapoints.ColorSpace.RGB_ALPHA: 4,
}
@@ -265,7 +265,7 @@ def get_num_channels(color_space):
def make_image_loader(
    size="random",
    *,
-    color_space=features.ColorSpace.RGB,
+    color_space=datapoints.ColorSpace.RGB,
    extra_dims=(),
    dtype=torch.float32,
    constant_alpha=True,
@@ -276,9 +276,9 @@ def make_image_loader(
    def fn(shape, dtype, device):
        max_value = get_max_value(dtype)
        data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
-        if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha:
+        if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha:
            data[..., -1, :, :] = max_value
-        return features.Image(data, color_space=color_space)
+        return datapoints.Image(data, color_space=color_space)
    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space)
@@ -290,10 +290,10 @@ def make_image_loaders(
    *,
    sizes=DEFAULT_SPATIAL_SIZES,
    color_spaces=(
-        features.ColorSpace.GRAY,
-        features.ColorSpace.GRAY_ALPHA,
-        features.ColorSpace.RGB,
-        features.ColorSpace.RGB_ALPHA,
+        datapoints.ColorSpace.GRAY,
+        datapoints.ColorSpace.GRAY_ALPHA,
+        datapoints.ColorSpace.RGB,
+        datapoints.ColorSpace.RGB_ALPHA,
    ),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.float32, torch.uint8),
@@ -306,7 +306,7 @@ def make_image_loaders(
make_images = from_loaders(make_image_loaders)
-def make_image_loader_for_interpolation(size="random", *, color_space=features.ColorSpace.RGB, dtype=torch.uint8):
+def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8):
    size = _parse_spatial_size(size)
    num_channels = get_num_channels(color_space)
@@ -318,24 +318,24 @@ def make_image_loader_for_interpolation(size="random", *, color_space=features.C
            .resize((width, height))
            .convert(
                {
-                    features.ColorSpace.GRAY: "L",
-                    features.ColorSpace.GRAY_ALPHA: "LA",
-                    features.ColorSpace.RGB: "RGB",
-                    features.ColorSpace.RGB_ALPHA: "RGBA",
+                    datapoints.ColorSpace.GRAY: "L",
+                    datapoints.ColorSpace.GRAY_ALPHA: "LA",
+                    datapoints.ColorSpace.RGB: "RGB",
+                    datapoints.ColorSpace.RGB_ALPHA: "RGBA",
                }[color_space]
            )
        )
        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
-        return features.Image(image_tensor, color_space=color_space)
+        return datapoints.Image(image_tensor, color_space=color_space)
    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space)
def make_image_loaders_for_interpolation(
    sizes=((233, 147),),
-    color_spaces=(features.ColorSpace.RGB,),
+    color_spaces=(datapoints.ColorSpace.RGB,),
    dtypes=(torch.uint8,),
):
    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
@@ -344,7 +344,7 @@ def make_image_loaders_for_interpolation(
@dataclasses.dataclass
class BoundingBoxLoader(TensorLoader):
-    format: features.BoundingBoxFormat
+    format: datapoints.BoundingBoxFormat
    spatial_size: Tuple[int, int]
@@ -362,11 +362,11 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32):
    if isinstance(format, str):
-        format = features.BoundingBoxFormat[format]
+        format = datapoints.BoundingBoxFormat[format]
    if format not in {
-        features.BoundingBoxFormat.XYXY,
-        features.BoundingBoxFormat.XYWH,
-        features.BoundingBoxFormat.CXCYWH,
+        datapoints.BoundingBoxFormat.XYXY,
+        datapoints.BoundingBoxFormat.XYWH,
+        datapoints.BoundingBoxFormat.CXCYWH,
    }:
        raise pytest.UsageError(f"Can't make bounding box in format {format}")
@@ -378,19 +378,19 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dt
            raise pytest.UsageError()
        if any(dim == 0 for dim in extra_dims):
-            return features.BoundingBox(
+            return datapoints.BoundingBox(
                torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
            )
        height, width = spatial_size
-        if format == features.BoundingBoxFormat.XYXY:
+        if format == datapoints.BoundingBoxFormat.XYXY:
            x1 = torch.randint(0, width // 2, extra_dims)
            y1 = torch.randint(0, height // 2, extra_dims)
            x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
            y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
            parts = (x1, y1, x2, y2)
-        elif format == features.BoundingBoxFormat.XYWH:
+        elif format == datapoints.BoundingBoxFormat.XYWH:
            x = torch.randint(0, width // 2, extra_dims)
            y = torch.randint(0, height // 2, extra_dims)
            w = randint_with_tensor_bounds(1, width - x)
@@ -403,7 +403,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dt
            h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
            parts = (cx, cy, w, h)
-        return features.BoundingBox(
+        return datapoints.BoundingBox(
            torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
        )
@@ -416,7 +416,7 @@ make_bounding_box = from_loader(make_bounding_box_loader)
def make_bounding_box_loaders(
    *,
    extra_dims=DEFAULT_EXTRA_DIMS,
-    formats=tuple(features.BoundingBoxFormat),
+    formats=tuple(datapoints.BoundingBoxFormat),
    spatial_size="random",
    dtypes=(torch.float32, torch.int64),
):
@@ -456,7 +456,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
        # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
        # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
        data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
-        return features.Label(data, categories=categories)
+        return datapoints.Label(data, categories=categories)
    return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
@@ -480,7 +480,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int
        # since `one_hot` only supports int64
        label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
        data = one_hot(label, num_classes=num_categories).to(dtype)
-        return features.OneHotLabel(data, categories=categories)
+        return datapoints.OneHotLabel(data, categories=categories)
    return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
@@ -509,7 +509,7 @@ def make_detection_mask_loader(size="random", *, num_objects="random", extra_dim
    def fn(shape, dtype, device):
        data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device)
-        return features.Mask(data)
+        return datapoints.Mask(data)
    return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype)
@@ -537,7 +537,7 @@ def make_segmentation_mask_loader(size="random", *, num_categories="random", ext
    def fn(shape, dtype, device):
        data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device)
-        return features.Mask(data)
+        return datapoints.Mask(data)
    return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
@@ -583,7 +583,7 @@ class VideoLoader(ImageLoader):
def make_video_loader(
    size="random",
    *,
-    color_space=features.ColorSpace.RGB,
+    color_space=datapoints.ColorSpace.RGB,
    num_frames="random",
    extra_dims=(),
    dtype=torch.uint8,
@@ -593,7 +593,7 @@ def make_video_loader(
    def fn(shape, dtype, device):
        video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
-        return features.Video(video, color_space=color_space)
+        return datapoints.Video(video, color_space=color_space)
    return VideoLoader(
        fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
@@ -607,8 +607,8 @@ def make_video_loaders(
    *,
    sizes=DEFAULT_SPATIAL_SIZES,
    color_spaces=(
-        features.ColorSpace.GRAY,
-        features.ColorSpace.RGB,
+        datapoints.ColorSpace.GRAY,
+        datapoints.ColorSpace.RGB,
    ),
    num_frames=(1, 0, "random"),
    extra_dims=DEFAULT_EXTRA_DIMS,
...
@@ -4,7 +4,7 @@ import pytest
import torchvision.prototype.transforms.functional as F
from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -139,20 +139,20 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.horizontal_flip,
        kernels={
-            features.Image: F.horizontal_flip_image_tensor,
-            features.Video: F.horizontal_flip_video,
-            features.BoundingBox: F.horizontal_flip_bounding_box,
-            features.Mask: F.horizontal_flip_mask,
+            datapoints.Image: F.horizontal_flip_image_tensor,
+            datapoints.Video: F.horizontal_flip_video,
+            datapoints.BoundingBox: F.horizontal_flip_bounding_box,
+            datapoints.Mask: F.horizontal_flip_mask,
        },
        pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"),
    ),
    DispatcherInfo(
        F.resize,
        kernels={
-            features.Image: F.resize_image_tensor,
-            features.Video: F.resize_video,
-            features.BoundingBox: F.resize_bounding_box,
-            features.Mask: F.resize_mask,
+            datapoints.Image: F.resize_image_tensor,
+            datapoints.Video: F.resize_video,
+            datapoints.BoundingBox: F.resize_bounding_box,
+            datapoints.Mask: F.resize_mask,
        },
        pil_kernel_info=PILKernelInfo(F.resize_image_pil),
        test_marks=[
@@ -162,10 +162,10 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.affine,
        kernels={
-            features.Image: F.affine_image_tensor,
-            features.Video: F.affine_video,
-            features.BoundingBox: F.affine_bounding_box,
-            features.Mask: F.affine_mask,
+            datapoints.Image: F.affine_image_tensor,
+            datapoints.Video: F.affine_video,
+            datapoints.BoundingBox: F.affine_bounding_box,
+            datapoints.Mask: F.affine_mask,
        },
        pil_kernel_info=PILKernelInfo(F.affine_image_pil),
        test_marks=[
@@ -179,20 +179,20 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.vertical_flip,
        kernels={
-            features.Image: F.vertical_flip_image_tensor,
-            features.Video: F.vertical_flip_video,
-            features.BoundingBox: F.vertical_flip_bounding_box,
-            features.Mask: F.vertical_flip_mask,
+            datapoints.Image: F.vertical_flip_image_tensor,
+            datapoints.Video: F.vertical_flip_video,
+            datapoints.BoundingBox: F.vertical_flip_bounding_box,
+            datapoints.Mask: F.vertical_flip_mask,
        },
        pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"),
    ),
    DispatcherInfo(
        F.rotate,
        kernels={
-            features.Image: F.rotate_image_tensor,
-            features.Video: F.rotate_video,
-            features.BoundingBox: F.rotate_bounding_box,
-            features.Mask: F.rotate_mask,
+            datapoints.Image: F.rotate_image_tensor,
+            datapoints.Video: F.rotate_video,
+            datapoints.BoundingBox: F.rotate_bounding_box,
+            datapoints.Mask: F.rotate_mask,
        },
        pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
        test_marks=[
@@ -204,30 +204,30 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.crop,
        kernels={
-            features.Image: F.crop_image_tensor,
-            features.Video: F.crop_video,
-            features.BoundingBox: F.crop_bounding_box,
-            features.Mask: F.crop_mask,
+            datapoints.Image: F.crop_image_tensor,
+            datapoints.Video: F.crop_video,
+            datapoints.BoundingBox: F.crop_bounding_box,
+            datapoints.Mask: F.crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
    ),
    DispatcherInfo(
        F.resized_crop,
        kernels={
-            features.Image: F.resized_crop_image_tensor,
-            features.Video: F.resized_crop_video,
-            features.BoundingBox: F.resized_crop_bounding_box,
-            features.Mask: F.resized_crop_mask,
+            datapoints.Image: F.resized_crop_image_tensor,
+            datapoints.Video: F.resized_crop_video,
+            datapoints.BoundingBox: F.resized_crop_bounding_box,
+            datapoints.Mask: F.resized_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
    ),
    DispatcherInfo(
        F.pad,
        kernels={
-            features.Image: F.pad_image_tensor,
-            features.Video: F.pad_video,
-            features.BoundingBox: F.pad_bounding_box,
-            features.Mask: F.pad_mask,
+            datapoints.Image: F.pad_image_tensor,
+            datapoints.Video: F.pad_video,
+            datapoints.BoundingBox: F.pad_bounding_box,
+            datapoints.Mask: F.pad_mask,
        },
        pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
        test_marks=[
@@ -251,10 +251,10 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.perspective,
        kernels={
-            features.Image: F.perspective_image_tensor,
-            features.Video: F.perspective_video,
-            features.BoundingBox: F.perspective_bounding_box,
-            features.Mask: F.perspective_mask,
+            datapoints.Image: F.perspective_image_tensor,
+            datapoints.Video: F.perspective_video,
+            datapoints.BoundingBox: F.perspective_bounding_box,
+            datapoints.Mask: F.perspective_mask,
        },
        pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
        test_marks=[
@@ -264,20 +264,20 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.elastic,
        kernels={
-            features.Image: F.elastic_image_tensor,
-            features.Video: F.elastic_video,
-            features.BoundingBox: F.elastic_bounding_box,
-            features.Mask: F.elastic_mask,
+            datapoints.Image: F.elastic_image_tensor,
+            datapoints.Video: F.elastic_video,
+            datapoints.BoundingBox: F.elastic_bounding_box,
+            datapoints.Mask: F.elastic_mask,
        },
        pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
    ),
    DispatcherInfo(
        F.center_crop,
        kernels={
-            features.Image: F.center_crop_image_tensor,
-            features.Video: F.center_crop_video,
-            features.BoundingBox: F.center_crop_bounding_box,
-            features.Mask: F.center_crop_mask,
+            datapoints.Image: F.center_crop_image_tensor,
+            datapoints.Video: F.center_crop_video,
+            datapoints.BoundingBox: F.center_crop_bounding_box,
+            datapoints.Mask: F.center_crop_mask,
        },
        pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
        test_marks=[
@@ -287,8 +287,8 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.gaussian_blur,
        kernels={
-            features.Image: F.gaussian_blur_image_tensor,
-            features.Video: F.gaussian_blur_video,
+            datapoints.Image: F.gaussian_blur_image_tensor,
+            datapoints.Video: F.gaussian_blur_video,
        },
        pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
        test_marks=[
@@ -299,56 +299,56 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.equalize,
        kernels={
-            features.Image: F.equalize_image_tensor,
-            features.Video: F.equalize_video,
+            datapoints.Image: F.equalize_image_tensor,
+            datapoints.Video: F.equalize_video,
        },
        pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
    ),
    DispatcherInfo(
        F.invert,
        kernels={
-            features.Image: F.invert_image_tensor,
-            features.Video: F.invert_video,
+            datapoints.Image: F.invert_image_tensor,
+            datapoints.Video: F.invert_video,
        },
        pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"),
    ),
    DispatcherInfo(
        F.posterize,
        kernels={
-            features.Image: F.posterize_image_tensor,
-            features.Video: F.posterize_video,
+            datapoints.Image: F.posterize_image_tensor,
+            datapoints.Video: F.posterize_video,
        },
        pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"),
    ),
    DispatcherInfo(
        F.solarize,
        kernels={
-            features.Image: F.solarize_image_tensor,
-            features.Video: F.solarize_video,
+            datapoints.Image: F.solarize_image_tensor,
+            datapoints.Video: F.solarize_video,
        },
        pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"),
    ),
    DispatcherInfo(
        F.autocontrast,
        kernels={
-            features.Image: F.autocontrast_image_tensor,
-            features.Video: F.autocontrast_video,
+            datapoints.Image: F.autocontrast_image_tensor,
+            datapoints.Video: F.autocontrast_video,
        },
        pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_sharpness,
        kernels={
-            features.Image: F.adjust_sharpness_image_tensor,
-            features.Video: F.adjust_sharpness_video,
+            datapoints.Image: F.adjust_sharpness_image_tensor,
+            datapoints.Video: F.adjust_sharpness_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
    ),
    DispatcherInfo(
        F.erase,
        kernels={
-            features.Image: F.erase_image_tensor,
-            features.Video: F.erase_video,
+            datapoints.Image: F.erase_image_tensor,
+            datapoints.Video: F.erase_video,
        },
        pil_kernel_info=PILKernelInfo(F.erase_image_pil),
        test_marks=[
@@ -358,48 +358,48 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.adjust_brightness,
        kernels={
-            features.Image: F.adjust_brightness_image_tensor,
-            features.Video: F.adjust_brightness_video,
+            datapoints.Image: F.adjust_brightness_image_tensor,
+            datapoints.Video: F.adjust_brightness_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_contrast,
        kernels={
-            features.Image: F.adjust_contrast_image_tensor,
-            features.Video: F.adjust_contrast_video,
+            datapoints.Image: F.adjust_contrast_image_tensor,
+            datapoints.Video: F.adjust_contrast_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_gamma,
        kernels={
-            features.Image: F.adjust_gamma_image_tensor,
-            features.Video: F.adjust_gamma_video,
+            datapoints.Image: F.adjust_gamma_image_tensor,
+            datapoints.Video: F.adjust_gamma_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_hue,
        kernels={
-            features.Image: F.adjust_hue_image_tensor,
-            features.Video: F.adjust_hue_video,
+            datapoints.Image: F.adjust_hue_image_tensor,
+            datapoints.Video: F.adjust_hue_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
    ),
    DispatcherInfo(
        F.adjust_saturation,
        kernels={
-            features.Image: F.adjust_saturation_image_tensor,
-            features.Video: F.adjust_saturation_video,
+            datapoints.Image: F.adjust_saturation_image_tensor,
+            datapoints.Video: F.adjust_saturation_video,
        },
        pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
    ),
    DispatcherInfo(
        F.five_crop,
        kernels={
-            features.Image: F.five_crop_image_tensor,
-            features.Video: F.five_crop_video,
+            datapoints.Image: F.five_crop_image_tensor,
+            datapoints.Video: F.five_crop_video,
        },
        pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
        test_marks=[
@@ -410,8 +410,8 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.ten_crop,
        kernels={
-            features.Image: F.ten_crop_image_tensor,
-            features.Video: F.ten_crop_video,
+            datapoints.Image: F.ten_crop_image_tensor,
+            datapoints.Video: F.ten_crop_video,
        },
        test_marks=[
            xfail_jit_python_scalar_arg("size"),
@@ -422,8 +422,8 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.normalize,
        kernels={
-            features.Image: F.normalize_image_tensor,
-            features.Video: F.normalize_video,
+            datapoints.Image: F.normalize_image_tensor,
+            datapoints.Video: F.normalize_video,
        },
        test_marks=[
            skip_dispatch_feature,
@@ -434,8 +434,8 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.convert_dtype,
        kernels={
-            features.Image: F.convert_dtype_image_tensor,
-            features.Video: F.convert_dtype_video,
+            datapoints.Image: F.convert_dtype_image_tensor,
+            datapoints.Video: F.convert_dtype_video,
        },
        test_marks=[
            skip_dispatch_feature,
@@ -444,7 +444,7 @@ DISPATCHER_INFOS = [
    DispatcherInfo(
        F.uniform_temporal_subsample,
        kernels={
-            features.Video: F.uniform_temporal_subsample_video,
+            datapoints.Video: F.uniform_temporal_subsample_video,
        },
        test_marks=[
            skip_dispatch_feature,
...
import pytest
import torch
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
def test_isinstance():
    assert isinstance(
-        features.Label([0, 1, 0], categories=["foo", "bar"]),
+        datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
        torch.Tensor,
    )
def test_wrapping_no_copy():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = features.Label(tensor, categories=["foo", "bar"])
+    label = datapoints.Label(tensor, categories=["foo", "bar"])
    assert label.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = features.Label(tensor, categories=["foo", "bar"])
+    label = datapoints.Label(tensor, categories=["foo", "bar"])
    label_to = label.to(torch.int32)
-    assert type(label_to) is features.Label
+    assert type(label_to) is datapoints.Label
    assert label_to.dtype is torch.int32
    assert label_to.categories is label.categories
def test_to_feature_reference():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
+    label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
    tensor_to = tensor.to(label)
@@ -40,31 +40,31 @@ def test_to_feature_reference():
def test_clone_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = features.Label(tensor, categories=["foo", "bar"])
+    label = datapoints.Label(tensor, categories=["foo", "bar"])
    label_clone = label.clone()
-    assert type(label_clone) is features.Label
+    assert type(label_clone) is datapoints.Label
    assert label_clone.data_ptr() != label.data_ptr()
    assert label_clone.categories is label.categories
def test_requires_grad__wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
-    label = features.Label(tensor, categories=["foo", "bar"])
+    label = datapoints.Label(tensor, categories=["foo", "bar"])
    assert not label.requires_grad
    label_requires_grad = label.requires_grad_(True)
-    assert type(label_requires_grad) is features.Label
+    assert type(label_requires_grad) is datapoints.Label
    assert label.requires_grad
    assert label_requires_grad.requires_grad
def test_other_op_no_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = features.Label(tensor, categories=["foo", "bar"])
+    label = datapoints.Label(tensor, categories=["foo", "bar"])
    # any operation besides .to() and .clone() will do here
    output = label * 2
@@ -82,32 +82,32 @@ def test_other_op_no_wrapping():
)
def test_no_tensor_output_op_no_wrapping(op):
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = features.Label(tensor, categories=["foo", "bar"])
+    label = datapoints.Label(tensor, categories=["foo", "bar"])
    output = op(label)
-    assert type(output) is not features.Label
+    assert type(output) is not datapoints.Label
def test_inplace_op_no_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = features.Label(tensor, categories=["foo", "bar"])
+    label = datapoints.Label(tensor, categories=["foo", "bar"])
    output = label.add_(0)
    assert type(output) is torch.Tensor
-    assert type(label) is features.Label
+    assert type(label) is datapoints.Label
def test_wrap_like():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = features.Label(tensor, categories=["foo", "bar"])
+    label = datapoints.Label(tensor, categories=["foo", "bar"])
    # any operation besides .to() and .clone() will do here
    output = label * 2
-    label_new = features.Label.wrap_like(label, output)
+    label_new = datapoints.Label.wrap_like(label, output)
-    assert type(label_new) is features.Label
+    assert type(label_new) is datapoints.Label
    assert label_new.data_ptr() == output.data_ptr()
    assert label_new.categories is label.categories
@@ -6,6 +6,8 @@ from pathlib import Path
import pytest
import torch
+import torchvision.prototype.transforms.utils
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair
from torch.utils.data import DataLoader
@@ -14,7 +16,7 @@ from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision._utils import sequence_to_str
-from torchvision.prototype import datasets, features, transforms
+from torchvision.prototype import datapoints, datasets, transforms
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -130,7 +132,11 @@ class TestCommon:
    def test_no_simple_tensors(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
-        simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)}
+        simple_tensors = {
+            key
+            for key, value in next_consume(iter(dataset)).items()
+            if torchvision.prototype.transforms.utils.is_simple_tensor(value)
+        }
        if simple_tensors:
            raise AssertionError(
                f"The values of key(s) "
@@ -258,7 +264,7 @@ class TestUSPS:
        assert "image" in sample
        assert "label" in sample
-        assert isinstance(sample["image"], features.Image)
-        assert isinstance(sample["label"], features.Label)
+        assert isinstance(sample["image"], datapoints.Image)
+        assert isinstance(sample["label"], datapoints.Label)
        assert sample["image"].shape == (1, 16, 16)
@@ -24,13 +24,13 @@ from prototype_common_utils import (
)
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
-from torchvision.prototype import features, transforms as prototype_transforms
+from torchvision.prototype import datapoints, transforms as prototype_transforms
from torchvision.prototype.transforms import functional as prototype_F
from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import query_spatial_size
from torchvision.transforms import functional as legacy_F
-DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
+DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)])
class ConsistencyConfig:
@@ -138,7 +138,7 @@ CONSISTENCY_CONFIGS = [
        # Make sure that the product of the height, width and number of channels matches the number of elements in
        # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
        make_images_kwargs=dict(
-            DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[features.ColorSpace.RGB]
+            DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB]
        ),
        supports_pil=False,
    ),
@@ -150,7 +150,7 @@ CONSISTENCY_CONFIGS = [
            ArgsKwargs(num_output_channels=3),
        ],
        make_images_kwargs=dict(
-            DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY]
+            DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY]
        ),
    ),
    ConsistencyConfig(
@@ -173,10 +173,10 @@ CONSISTENCY_CONFIGS = [
        [ArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
-                features.ColorSpace.GRAY,
-                features.ColorSpace.GRAY_ALPHA,
-                features.ColorSpace.RGB,
-                features.ColorSpace.RGB_ALPHA,
+                datapoints.ColorSpace.GRAY,
+                datapoints.ColorSpace.GRAY_ALPHA,
+                datapoints.ColorSpace.RGB,
+                datapoints.ColorSpace.RGB_ALPHA,
            ],
            extra_dims=[()],
        ),
@@ -733,7 +733,7 @@ class TestAATransforms:
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
-            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
@@ -771,7 +771,7 @@ class TestAATransforms:
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
-            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
@@ -819,7 +819,7 @@ class TestAATransforms:
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
-            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
@@ -868,7 +868,7 @@ class TestAATransforms:
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
-            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
@@ -902,7 +902,7 @@ class TestRefDetTransforms:
        size = (600, 800)
        num_objects = 22
-        pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB))
+        pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
        target = {
            "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -912,7 +912,7 @@ class TestRefDetTransforms:
        yield (pil_image, target)
-        tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
+        tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
        target = {
            "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -922,7 +922,7 @@ class TestRefDetTransforms:
        yield (tensor_image, target)
-        feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
+        feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
        target = {
            "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -1006,7 +1006,7 @@ class TestRefSegTransforms:
            conv_fns.extend([torch.Tensor, lambda x: x])
        for conv_fn in conv_fns:
-            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
+            feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
            dp = (conv_fn(feature_image), feature_mask)
@@ -1053,7 +1053,7 @@ class TestRefSegTransforms:
            seg_transforms.RandomCrop(size=480),
            prototype_transforms.Compose(
                [
-                    PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
+                    PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
                    prototype_transforms.RandomCrop(size=480),
                ]
            ),
...
@@ -10,12 +10,14 @@ import PIL.Image
import pytest
import torch
+import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map
-from torchvision.prototype import features
+from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box
@@ -147,18 +149,22 @@ class TestKernels:
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)
-        feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
+        feature_type = (
+            datapoints.Image
+            if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
+            else type(batched_input)
+        )
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
-            features.Image: 3,
-            features.BoundingBox: 1,
+            datapoints.Image: 3,
+            datapoints.BoundingBox: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
            # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one
            # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
            # common ground.
-            features.Mask: 2,
-            features.Video: 4,
+            datapoints.Mask: 2,
+            datapoints.Video: 4,
        }.get(feature_type)
        if data_dims is None:
            raise pytest.UsageError(
@@ -281,8 +287,8 @@ def spy_on(mocker):
class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
-        [info for info in DISPATCHER_INFOS if features.Image in info.kernels],
-        args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
+        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
+        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    @ignore_jit_warning_no_profile
@@ -323,7 +329,7 @@ class TestDispatchers:
        (image_feature, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = torch.Tensor(image_feature)
-        kernel_info = info.kernel_infos[features.Image]
+        kernel_info = info.kernel_infos[datapoints.Image]
        spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
        info.dispatcher(image_simple_tensor, *other_args, **kwargs)
@@ -332,7 +338,7 @@ class TestDispatchers:
    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
-        args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
+        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_dispatch_pil(self, info, args_kwargs, spy_on):
        (image_feature, *other_args), kwargs = args_kwargs.load()
@@ -403,7 +409,7 @@ class TestDispatchers:
    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_dispatcher_feature_signatures_consistency(self, info):
        try:
-            feature_method = getattr(features._Feature, info.id)
+            feature_method = getattr(datapoints._datapoint.Datapoint, info.id)
        except AttributeError:
            pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")
@@ -413,7 +419,7 @@ class TestDispatchers:
        feature_signature = inspect.signature(feature_method)
        feature_params = list(feature_signature.parameters.values())[1:]
-        # Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined,
+        # Because we use `from __future__ import annotations` inside the module where `features._datapoint` is defined,
        # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
        # concrete dispatcher annotations.
        feature_annotations = get_type_hints(feature_method)
@@ -505,8 +511,8 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
        [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
        [1, 1, 5, 5],
    ]
-    in_boxes = features.BoundingBox(
-        in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, dtype=torch.float64, device=device
+    in_boxes = datapoints.BoundingBox(
+        in_boxes,
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=spatial_size,
+        dtype=torch.float64,
+        device=device,
    )
    # Tested parameters
    angle = 63
@@ -572,7 +582,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
        height, width = bbox.spatial_size
        bbox_xyxy = convert_format_bounding_box(
-            bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
+            bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
        )
        points = np.array(
            [
@@ -605,15 +615,15 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
            height = int(height - 2 * tr_y)
            width = int(width - 2 * tr_x)
-        out_bbox = features.BoundingBox(
+        out_bbox = datapoints.BoundingBox(
            out_bbox,
-            format=features.BoundingBoxFormat.XYXY,
+            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=(height, width),
            dtype=bbox.dtype,
            device=bbox.device,
        )
        return (
-            convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format),
+            convert_format_bounding_box(out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format),
            (height, width),
        )
@@ -641,7 +651,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
    expected_bboxes = []
    for bbox in bboxes:
-        bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+        bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
        expected_bbox, expected_spatial_size = _compute_expected_bbox(bbox, -angle, expand, center_)
        expected_bboxes.append(expected_bbox)
    if len(expected_bboxes) > 1:
@@ -664,8 +674,8 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
        [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
        [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
    ]
-    in_boxes = features.BoundingBox(
-        in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, dtype=torch.float64, device=device
+    in_boxes = datapoints.BoundingBox(
+        in_boxes,
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=spatial_size,
+        dtype=torch.float64,
+        device=device,
    )
    # Tested parameters
    angle = 45
@@ -725,7 +739,7 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "format",
-    [features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH],
+    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
@@ -755,9 +769,11 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
        [50.0, 5.0, 70.0, 22.0],
        [45.0, 46.0, 56.0, 62.0],
    ]
-    in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=size, device=device)
-    if format != features.BoundingBoxFormat.XYXY:
-        in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)
+    in_boxes = datapoints.BoundingBox(
+        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=size, device=device
+    )
+    if format != datapoints.BoundingBoxFormat.XYXY:
+        in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
    output_boxes, output_spatial_size = F.crop_bounding_box(
        in_boxes,
@@ -768,8 +784,8 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
        size[1],
    )
-    if format != features.BoundingBoxFormat.XYXY:
-        output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)
+    if format != datapoints.BoundingBoxFormat.XYXY:
+        output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_spatial_size, size) torch.testing.assert_close(output_spatial_size, size)
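The crop test compares `F.crop_bounding_box` against hand-computed expectations; the underlying arithmetic is just a translation of the box coordinates by the crop origin. A sketch with a hypothetical helper on plain tensors, ignoring any clamping to the crop area:

```python
import torch


def crop_xyxy(boxes: torch.Tensor, top: int, left: int, height: int, width: int):
    # Cropping a region whose top-left corner is (left, top) translates XYXY
    # boxes by (-left, -top); the new spatial size is the crop size itself.
    shifted = boxes.clone()
    shifted[..., 0::2] -= left  # x1, x2
    shifted[..., 1::2] -= top   # y1, y2
    return shifted, (height, width)


boxes = torch.tensor([[10.0, 5.0, 25.0, 35.0]])
out, spatial_size = crop_xyxy(boxes, top=5, left=5, height=50, width=50)
assert out.tolist() == [[5.0, 0.0, 20.0, 30.0]] and spatial_size == (50, 50)
```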
...@@ -802,7 +818,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): ...@@ -802,7 +818,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"format", "format",
[features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH], [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"top, left, height, width, size", "top, left, height, width, size",
...@@ -831,16 +847,16 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height ...@@ -831,16 +847,16 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size)) expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
expected_bboxes = torch.tensor(expected_bboxes, device=device) expected_bboxes = torch.tensor(expected_bboxes, device=device)
in_boxes = features.BoundingBox( in_boxes = datapoints.BoundingBox(
in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device
) )
if format != features.BoundingBoxFormat.XYXY: if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size) output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
if format != features.BoundingBoxFormat.XYXY: if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_spatial_size, size) torch.testing.assert_close(output_spatial_size, size)
...@@ -868,14 +884,14 @@ def test_correctness_pad_bounding_box(device, padding): ...@@ -868,14 +884,14 @@ def test_correctness_pad_bounding_box(device, padding):
bbox_dtype = bbox.dtype bbox_dtype = bbox.dtype
bbox = ( bbox = (
bbox.clone() bbox.clone()
if bbox_format == features.BoundingBoxFormat.XYXY if bbox_format == datapoints.BoundingBoxFormat.XYXY
else convert_format_bounding_box(bbox, bbox_format, features.BoundingBoxFormat.XYXY) else convert_format_bounding_box(bbox, bbox_format, datapoints.BoundingBoxFormat.XYXY)
) )
bbox[0::2] += pad_left bbox[0::2] += pad_left
bbox[1::2] += pad_up bbox[1::2] += pad_up
bbox = convert_format_bounding_box(bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format) bbox = convert_format_bounding_box(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format)
if bbox.dtype != bbox_dtype: if bbox.dtype != bbox_dtype:
# Temporary cast to original dtype # Temporary cast to original dtype
# e.g. float32 -> int # e.g. float32 -> int
...@@ -903,7 +919,7 @@ def test_correctness_pad_bounding_box(device, padding): ...@@ -903,7 +919,7 @@ def test_correctness_pad_bounding_box(device, padding):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, padding)) expected_bboxes.append(_compute_expected_bbox(bbox, padding))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
...@@ -949,7 +965,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -949,7 +965,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
) )
bbox_xyxy = convert_format_bounding_box( bbox_xyxy = convert_format_bounding_box(
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
) )
points = np.array( points = np.array(
[ [
...@@ -968,14 +984,16 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -968,14 +984,16 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
np.max(transformed_points[:, 0]), np.max(transformed_points[:, 0]),
np.max(transformed_points[:, 1]), np.max(transformed_points[:, 1]),
] ]
out_bbox = features.BoundingBox( out_bbox = datapoints.BoundingBox(
np.array(out_bbox), np.array(out_bbox),
format=features.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=bbox.spatial_size, spatial_size=bbox.spatial_size,
dtype=bbox.dtype, dtype=bbox.dtype,
device=bbox.device, device=bbox.device,
) )
return convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format) return convert_format_bounding_box(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format
)
spatial_size = (32, 38) spatial_size = (32, 38)
...@@ -1000,7 +1018,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -1000,7 +1018,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
...@@ -1019,7 +1037,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -1019,7 +1037,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
format_ = bbox.format format_ = bbox.format
spatial_size_ = bbox.spatial_size spatial_size_ = bbox.spatial_size
dtype = bbox.dtype dtype = bbox.dtype
bbox = convert_format_bounding_box(bbox.float(), format_, features.BoundingBoxFormat.XYWH) bbox = convert_format_bounding_box(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
if len(output_size_) == 1: if len(output_size_) == 1:
output_size_.append(output_size_[-1]) output_size_.append(output_size_[-1])
...@@ -1033,7 +1051,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -1033,7 +1051,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
bbox[3].item(), bbox[3].item(),
] ]
out_bbox = torch.tensor(out_bbox) out_bbox = torch.tensor(out_bbox)
out_bbox = convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_) out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
return out_bbox.to(dtype=dtype, device=bbox.device) return out_bbox.to(dtype=dtype, device=bbox.device)
for bboxes in make_bounding_boxes(extra_dims=((4,),)): for bboxes in make_bounding_boxes(extra_dims=((4,),)):
...@@ -1050,7 +1068,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -1050,7 +1068,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
...@@ -1135,7 +1153,7 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, ...@@ -1135,7 +1153,7 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
) )
image = features.Image(tensor) image = datapoints.Image(tensor)
out = fn(image, kernel_size=ksize, sigma=sigma) out = fn(image, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
...@@ -1147,7 +1165,7 @@ def test_normalize_output_type(): ...@@ -1147,7 +1165,7 @@ def test_normalize_output_type():
assert type(output) is torch.Tensor assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
inpt = make_image(color_space=features.ColorSpace.RGB) inpt = make_image(color_space=datapoints.ColorSpace.RGB)
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(output) is torch.Tensor assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
......
...@@ -3,42 +3,51 @@ import pytest ...@@ -3,42 +3,51 @@ import pytest
import torch import torch
import torchvision.prototype.transforms.utils
from prototype_common_utils import make_bounding_box, make_detection_mask, make_image from prototype_common_utils import make_bounding_box, make_detection_mask, make_image
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import has_all, has_any from torchvision.prototype.transforms.utils import has_all, has_any
IMAGE = make_image(color_space=features.ColorSpace.RGB) IMAGE = make_image(color_space=datapoints.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size)
MASK = make_detection_mask(size=IMAGE.spatial_size) MASK = make_detection_mask(size=IMAGE.spatial_size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("sample", "types", "expected"), ("sample", "types", "expected"),
[ [
((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True),
((MASK,), (features.Image, features.BoundingBox), False), ((MASK,), (datapoints.Image, datapoints.BoundingBox), False),
((BOUNDING_BOX,), (features.Image, features.Mask), False), ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
((IMAGE,), (features.BoundingBox, features.Mask), False), ((IMAGE,), (datapoints.BoundingBox, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
True, True,
), ),
((), (features.Image, features.BoundingBox, features.Mask), False), ((), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, features.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((IMAGE,), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), (
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), (torch.Tensor(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor),
True,
),
(
(to_image_pil(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor),
True,
),
], ],
) )
def test_has_any(sample, types, expected): def test_has_any(sample, types, expected):
...@@ -48,31 +57,31 @@ def test_has_any(sample, types, expected): ...@@ -48,31 +57,31 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("sample", "types", "expected"), ("sample", "types", "expected"),
[ [
((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
True, True,
), ),
((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), False),
((BOUNDING_BOX, MASK), (features.Image, features.Mask), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
((IMAGE, MASK), (features.BoundingBox, features.Mask), False), ((IMAGE, MASK), (datapoints.BoundingBox, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
True, True,
), ),
((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox, features.Mask), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
((IMAGE, MASK), (features.Image, features.BoundingBox, features.Mask), False), ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.Mask), False), ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.Mask)),), (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)),),
True, True,
), ),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
......
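The `has_any` and `has_all` helpers exercised above check whether a sample contains (any of / all of) the requested datapoint types, where a check may also be an arbitrary predicate such as `is_simple_tensor`. A simplified flat-tuple sketch of that logic — the real utilities additionally flatten nested samples, so this is not the torchvision implementation:

```python
from typing import Any, Callable, Tuple, Type, Union

Check = Union[Type, Callable[[Any], bool]]


def _matches(obj: Any, check: Check) -> bool:
    # A check is either a type (isinstance test) or a predicate callable.
    return isinstance(obj, check) if isinstance(check, type) else bool(check(obj))


def has_any(sample: Tuple[Any, ...], *checks: Check) -> bool:
    return any(_matches(obj, check) for obj in sample for check in checks)


def has_all(sample: Tuple[Any, ...], *checks: Check) -> bool:
    return all(any(_matches(obj, check) for obj in sample) for check in checks)


assert has_any((1, "a"), int, float) and not has_all((1, "a"), int, float)
```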
from . import features, models, transforms, utils from . import datapoints, models, transforms, utils
from ._bounding_box import BoundingBox, BoundingBoxFormat from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel from ._label import Label, OneHotLabel
from ._mask import Mask from ._mask import Mask
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from torchvision._utils import StrEnum from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
from ._feature import _Feature, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
class BoundingBoxFormat(StrEnum): class BoundingBoxFormat(StrEnum):
...@@ -15,7 +15,7 @@ class BoundingBoxFormat(StrEnum): ...@@ -15,7 +15,7 @@ class BoundingBoxFormat(StrEnum):
CXCYWH = StrEnum.auto() CXCYWH = StrEnum.auto()
class BoundingBox(_Feature): class BoundingBox(Datapoint):
format: BoundingBoxFormat format: BoundingBoxFormat
spatial_size: Tuple[int, int] spatial_size: Tuple[int, int]
......
...@@ -10,16 +10,12 @@ from torch.types import _device, _dtype, _size ...@@ -10,16 +10,12 @@ from torch.types import _device, _dtype, _size
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature") D = TypeVar("D", bound="Datapoint")
FillType = Union[int, float, Sequence[int], Sequence[float], None] FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Union[int, float, List[float], None] FillTypeJIT = Union[int, float, List[float], None]
def is_simple_tensor(inpt: Any) -> bool: class Datapoint(torch.Tensor):
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
class _Feature(torch.Tensor):
__F: Optional[ModuleType] = None __F: Optional[ModuleType] = None
@staticmethod @staticmethod
...@@ -31,22 +27,22 @@ class _Feature(torch.Tensor): ...@@ -31,22 +27,22 @@ class _Feature(torch.Tensor):
) -> torch.Tensor: ) -> torch.Tensor:
return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
# FIXME: this is just here for BC with the prototype datasets. Some datasets use the _Feature directly to have a # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
# a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
# interpreted as images. We should decide if we want a public no-op feature like `GenericFeature` or make this one # interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this
# public again. # one public again.
def __new__( def __new__(
cls, cls,
data: Any, data: Any,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: bool = False,
) -> _Feature: ) -> Datapoint:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return tensor.as_subclass(_Feature) return tensor.as_subclass(Datapoint)
@classmethod @classmethod
def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F: def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
# FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved, # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
# this method should be made abstract # this method should be made abstract
# raise NotImplementedError # raise NotImplementedError
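`wrap_like` is the hook a `Datapoint` subclass uses to re-attach its metadata to a plain tensor result. A minimal sketch of the pattern on a toy `torch.Tensor` subclass (not the actual torchvision classes):

```python
from __future__ import annotations

from typing import Type, TypeVar

import torch

T = TypeVar("T", bound="TaggedTensor")


class TaggedTensor(torch.Tensor):
    """Toy tensor subclass carrying a single piece of metadata."""

    tag: str

    def __new__(cls, data, *, tag: str) -> TaggedTensor:
        tensor = torch.as_tensor(data).as_subclass(cls)
        tensor.tag = tag
        return tensor

    @classmethod
    def wrap_like(cls: Type[T], other: T, tensor: torch.Tensor) -> T:
        # Re-wrap a plain tensor output and copy the metadata from `other`.
        wrapped = tensor.as_subclass(cls)
        wrapped.tag = other.tag
        return wrapped


x = TaggedTensor([1.0, 2.0], tag="bbox")
plain = torch.tensor([2.0, 4.0])            # e.g. the result of some kernel
y = TaggedTensor.wrap_like(x, plain)        # re-wrapped with x's metadata
assert isinstance(y, TaggedTensor) and y.tag == "bbox"
```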
...@@ -75,15 +71,15 @@ class _Feature(torch.Tensor): ...@@ -75,15 +71,15 @@ class _Feature(torch.Tensor):
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call. ``args`` and ``kwargs`` of the original call.
The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`_Feature` The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
use case, this has two downsides: use case, this has two downsides:
1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e. 1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
``return cls(func(*args, **kwargs))``, will fail for them. ``return cls(func(*args, **kwargs))``, will fail for them.
2. For most operations, there is no way of knowing if the input type is still valid for the output. 2. For most operations, there is no way of knowing if the input type is still valid for the output.
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in :attr:`~_Feature._NO_WRAPPING_EXCEPTIONS` listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
""" """
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality. # need to reimplement the functionality.
...@@ -98,9 +94,9 @@ class _Feature(torch.Tensor): ...@@ -98,9 +94,9 @@ class _Feature(torch.Tensor):
# Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(features.Image(...))` will invoke `features.Image.__torch_function__` with # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), features.Image())` first. Without this guard, the original `torch.Tensor` would # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `features.Image`. # be wrapped into a `datapoints.Image`.
if wrapper and isinstance(args[0], cls): if wrapper and isinstance(args[0], cls):
return wrapper(cls, args[0], output) # type: ignore[no-any-return] return wrapper(cls, args[0], output) # type: ignore[no-any-return]
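The docstring and comments above explain why automatic output wrapping is disabled for most operators: only the functions registered in `_NO_WRAPPING_EXCEPTIONS` are re-wrapped, and only when the primary operand is an instance of the class. A stripped-down sketch of that policy with a toy subclass and its own exception set — the real implementation additionally disables the torch-function override while running `func`, so this is an approximation, not the torchvision code:

```python
import torch


class NoAutoWrap(torch.Tensor):
    """Toy subclass: most ops return plain tensors, a few keep the subclass."""

    # Ops whose output is still "the same kind of thing" and may keep the type.
    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Let the default machinery run the op (it would keep the subclass) ...
        output = super().__torch_function__(func, types, args, kwargs or {})
        # ... then strip the subclass again unless func is explicitly allowed
        # and the primary operand is actually an instance of this class.
        if isinstance(output, torch.Tensor) and not (
            func in cls._NO_WRAPPING_EXCEPTIONS and args and isinstance(args[0], cls)
        ):
            output = output.as_subclass(torch.Tensor)
        return output


x = torch.tensor([1.0, 2.0]).as_subclass(NoAutoWrap)
assert type(x * 2) is torch.Tensor    # arbitrary op: unwrapped to a plain tensor
assert type(x.clone()) is NoAutoWrap  # allowed op: subclass preserved
```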
...@@ -123,11 +119,11 @@ class _Feature(torch.Tensor): ...@@ -123,11 +119,11 @@ class _Feature(torch.Tensor):
# until the first time we need reference to the functional module and it's shared across all instances of # until the first time we need reference to the functional module and it's shared across all instances of
# the class. This approach avoids the DataLoader issue described at # the class. This approach avoids the DataLoader issue described at
# https://github.com/pytorch/vision/pull/6476#discussion_r953588621 # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
if _Feature.__F is None: if Datapoint.__F is None:
from ..transforms import functional from ..transforms import functional
_Feature.__F = functional Datapoint.__F = functional
return _Feature.__F return Datapoint.__F
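The `_F` property above defers the import of the functional module until first use and then shares it across all instances, which sidesteps the circular-import/DataLoader issue referenced in the comment. The same lazy, class-level pattern with a stand-in module (`math` here is only a placeholder for the real deferred import):

```python
from types import ModuleType
from typing import Optional


class LazyFunctional:
    """Resolve a module lazily, once, and share it across all instances."""

    __F: Optional[ModuleType] = None

    @property
    def _F(self) -> ModuleType:
        # Deferring the import until first access breaks import cycles; the
        # resolved module is cached on the class, not on the instance.
        if LazyFunctional.__F is None:
            import math  # stand-in for e.g. `from ..transforms import functional`

            LazyFunctional.__F = math
        return LazyFunctional.__F


assert LazyFunctional()._F.sqrt(4.0) == 2.0
```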
# Add properties for common attributes like shape, dtype, device, ndim etc # Add properties for common attributes like shape, dtype, device, ndim etc
# this way we return the result without passing into __torch_function__ # this way we return the result without passing into __torch_function__
...@@ -151,10 +147,10 @@ class _Feature(torch.Tensor): ...@@ -151,10 +147,10 @@ class _Feature(torch.Tensor):
with DisableTorchFunction(): with DisableTorchFunction():
return super().dtype return super().dtype
def horizontal_flip(self) -> _Feature: def horizontal_flip(self) -> Datapoint:
return self return self
def vertical_flip(self) -> _Feature: def vertical_flip(self) -> Datapoint:
return self return self
# TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
...@@ -165,13 +161,13 @@ class _Feature(torch.Tensor): ...@@ -165,13 +161,13 @@ class _Feature(torch.Tensor):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> _Feature: ) -> Datapoint:
return self return self
def crop(self, top: int, left: int, height: int, width: int) -> _Feature: def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
return self return self
def center_crop(self, output_size: List[int]) -> _Feature: def center_crop(self, output_size: List[int]) -> Datapoint:
return self return self
def resized_crop( def resized_crop(
...@@ -183,7 +179,7 @@ class _Feature(torch.Tensor): ...@@ -183,7 +179,7 @@ class _Feature(torch.Tensor):
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> _Feature: ) -> Datapoint:
return self return self
def pad( def pad(
...@@ -191,7 +187,7 @@ class _Feature(torch.Tensor): ...@@ -191,7 +187,7 @@ class _Feature(torch.Tensor):
padding: Union[int, List[int]], padding: Union[int, List[int]],
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> _Feature: ) -> Datapoint:
return self return self
def rotate( def rotate(
...@@ -201,7 +197,7 @@ class _Feature(torch.Tensor): ...@@ -201,7 +197,7 @@ class _Feature(torch.Tensor):
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> _Feature: ) -> Datapoint:
return self return self
def affine( def affine(
...@@ -213,7 +209,7 @@ class _Feature(torch.Tensor): ...@@ -213,7 +209,7 @@ class _Feature(torch.Tensor):
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> _Feature: ) -> Datapoint:
return self return self
def perspective( def perspective(
...@@ -223,7 +219,7 @@ class _Feature(torch.Tensor): ...@@ -223,7 +219,7 @@ class _Feature(torch.Tensor):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> _Feature: ) -> Datapoint:
return self return self
def elastic( def elastic(
...@@ -231,45 +227,45 @@ class _Feature(torch.Tensor): ...@@ -231,45 +227,45 @@ class _Feature(torch.Tensor):
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> _Feature: ) -> Datapoint:
return self return self
def adjust_brightness(self, brightness_factor: float) -> _Feature: def adjust_brightness(self, brightness_factor: float) -> Datapoint:
return self return self
def adjust_saturation(self, saturation_factor: float) -> _Feature: def adjust_saturation(self, saturation_factor: float) -> Datapoint:
return self return self
def adjust_contrast(self, contrast_factor: float) -> _Feature: def adjust_contrast(self, contrast_factor: float) -> Datapoint:
return self return self
def adjust_sharpness(self, sharpness_factor: float) -> _Feature: def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
return self return self
def adjust_hue(self, hue_factor: float) -> _Feature: def adjust_hue(self, hue_factor: float) -> Datapoint:
return self return self
def adjust_gamma(self, gamma: float, gain: float = 1) -> _Feature: def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
return self return self
def posterize(self, bits: int) -> _Feature: def posterize(self, bits: int) -> Datapoint:
return self return self
def solarize(self, threshold: float) -> _Feature: def solarize(self, threshold: float) -> Datapoint:
return self return self
def autocontrast(self) -> _Feature: def autocontrast(self) -> Datapoint:
return self return self
def equalize(self) -> _Feature: def equalize(self) -> Datapoint:
return self return self
def invert(self) -> _Feature: def invert(self) -> Datapoint:
return self return self
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
return self return self
InputType = Union[torch.Tensor, PIL.Image.Image, _Feature] InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
InputTypeJIT = torch.Tensor InputTypeJIT = torch.Tensor
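All the remaining methods on the base class are deliberate no-ops: a bare `Datapoint` passes through every transform unchanged, and each concrete subclass overrides only the operations it knows how to perform, so the functional dispatchers never need `isinstance` ladders. A toy sketch of that arrangement (the class names, dispatcher, and assumed image width of 10 are all hypothetical):

```python
import torch


class Datapointish(torch.Tensor):
    """Toy base class: transform hooks are no-ops, so unknown kinds pass through."""

    def horizontal_flip(self) -> "Datapointish":
        return self


class Boxish(Datapointish):
    """Toy XYXY box that knows how to flip itself (assumed image width of 10)."""

    def horizontal_flip(self) -> "Boxish":
        flipped = self.clone()
        flipped[..., 0] = 10 - self[..., 2]  # new x1 = width - old x2
        flipped[..., 2] = 10 - self[..., 0]  # new x2 = width - old x1
        return flipped


def horizontal_flip(inpt):
    # Dispatcher sketch: datapoints dispatch to their own method, plain tensors
    # are treated as images and flipped along the last (width) dimension.
    return inpt.horizontal_flip() if isinstance(inpt, Datapointish) else inpt.flip(-1)


plain = torch.tensor([[2.0, 1.0, 4.0, 3.0]])
box = plain.as_subclass(Boxish)
other = plain.as_subclass(Datapointish)

assert horizontal_flip(box).tolist() == [[6.0, 1.0, 8.0, 3.0]]
assert horizontal_flip(other) is other  # base class: no-op pass-through
```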
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
from torchvision._utils import StrEnum from torchvision._utils import StrEnum
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ._feature import _Feature, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
class ColorSpace(StrEnum): class ColorSpace(StrEnum):
...@@ -57,7 +57,7 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace: ...@@ -57,7 +57,7 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:
return ColorSpace.OTHER return ColorSpace.OTHER
class Image(_Feature): class Image(Datapoint):
color_space: ColorSpace color_space: ColorSpace
@classmethod @classmethod
......
...@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union ...@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from ._feature import _Feature from ._datapoint import Datapoint
L = TypeVar("L", bound="_LabelBase") L = TypeVar("L", bound="_LabelBase")
class _LabelBase(_Feature): class _LabelBase(Datapoint):
categories: Optional[Sequence[str]] categories: Optional[Sequence[str]]
@classmethod @classmethod
......
...@@ -5,10 +5,10 @@ from typing import Any, List, Optional, Tuple, Union ...@@ -5,10 +5,10 @@ from typing import Any, List, Optional, Tuple, Union
import torch import torch
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from ._feature import _Feature, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
class Mask(_Feature): class Mask(Datapoint):
@classmethod @classmethod
def _wrap(cls, tensor: torch.Tensor) -> Mask: def _wrap(cls, tensor: torch.Tensor) -> Mask:
return tensor.as_subclass(cls) return tensor.as_subclass(cls)
......
...@@ -6,11 +6,11 @@ from typing import Any, List, Optional, Tuple, Union ...@@ -6,11 +6,11 @@ from typing import Any, List, Optional, Tuple, Union
import torch import torch
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ._feature import _Feature, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
from ._image import ColorSpace from ._image import ColorSpace
class Video(_Feature): class Video(Datapoint):
color_space: ColorSpace color_space: ColorSpace
@classmethod @classmethod
......
...@@ -4,6 +4,8 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union ...@@ -4,6 +4,8 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
hint_sharding, hint_sharding,
...@@ -12,7 +14,6 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -12,7 +14,6 @@ from torchvision.prototype.datasets.utils._internal import (
read_categories_file, read_categories_file,
read_mat, read_mat,
) )
from torchvision.prototype.features import _Feature, BoundingBox, Label
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
...@@ -114,7 +115,7 @@ class Caltech101(Dataset): ...@@ -114,7 +115,7 @@ class Caltech101(Dataset):
format="xyxy", format="xyxy",
spatial_size=image.spatial_size, spatial_size=image.spatial_size,
), ),
contour=_Feature(ann["obj_contour"].T), contour=Datapoint(ann["obj_contour"].T),
) )
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
......
...@@ -3,6 +3,8 @@ import pathlib ...@@ -3,6 +3,8 @@ import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
...@@ -11,7 +13,6 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -11,7 +13,6 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
path_accessor, path_accessor,
) )
from torchvision.prototype.features import _Feature, BoundingBox, Label
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
...@@ -148,7 +149,7 @@ class CelebA(Dataset): ...@@ -148,7 +149,7 @@ class CelebA(Dataset):
spatial_size=image.spatial_size, spatial_size=image.spatial_size,
), ),
landmarks={ landmarks={
landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) landmark: Datapoint((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
for landmark in {key[:-2] for key in landmarks.keys()} for landmark in {key[:-2] for key in landmarks.keys()}
}, },
) )
......