Unverified Commit a8007dcd authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

rename features._Feature to datapoints._Datapoint (#7002)

* rename features._Feature to datapoints.Datapoint

* _Datapoint to Datapoint

* move is_simple_tensor to transforms.utils

* fix CI

* move Datapoint out of public namespace
parent c093b9c0
...@@ -43,14 +43,14 @@ jobs: ...@@ -43,14 +43,14 @@ jobs:
id: setup id: setup
run: exit 0 run: exit 0
- name: Run prototype features tests - name: Run prototype datapoints tests
shell: bash shell: bash
run: | run: |
pytest \ pytest \
--durations=20 \ --durations=20 \
--cov=torchvision/prototype/features \ --cov=torchvision/prototype/datapoints \
--cov-report=term-missing \ --cov-report=term-missing \
test/test_prototype_features*.py test/test_prototype_datapoints*.py
- name: Run prototype transforms tests - name: Run prototype transforms tests
if: success() || ( failure() && steps.setup.conclusion == 'success' ) if: success() || ( failure() && steps.setup.conclusion == 'success' )
......
...@@ -15,7 +15,7 @@ import torch.testing ...@@ -15,7 +15,7 @@ import torch.testing
from datasets_utils import combinations_grid from datasets_utils import combinations_grid
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
from torchvision.transforms.functional_tensor import _max_value as get_max_value from torchvision.transforms.functional_tensor import _max_value as get_max_value
...@@ -238,7 +238,7 @@ class TensorLoader: ...@@ -238,7 +238,7 @@ class TensorLoader:
@dataclasses.dataclass @dataclasses.dataclass
class ImageLoader(TensorLoader): class ImageLoader(TensorLoader):
color_space: features.ColorSpace color_space: datapoints.ColorSpace
spatial_size: Tuple[int, int] = dataclasses.field(init=False) spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False) num_channels: int = dataclasses.field(init=False)
...@@ -248,10 +248,10 @@ class ImageLoader(TensorLoader): ...@@ -248,10 +248,10 @@ class ImageLoader(TensorLoader):
NUM_CHANNELS_MAP = { NUM_CHANNELS_MAP = {
features.ColorSpace.GRAY: 1, datapoints.ColorSpace.GRAY: 1,
features.ColorSpace.GRAY_ALPHA: 2, datapoints.ColorSpace.GRAY_ALPHA: 2,
features.ColorSpace.RGB: 3, datapoints.ColorSpace.RGB: 3,
features.ColorSpace.RGB_ALPHA: 4, datapoints.ColorSpace.RGB_ALPHA: 4,
} }
...@@ -265,7 +265,7 @@ def get_num_channels(color_space): ...@@ -265,7 +265,7 @@ def get_num_channels(color_space):
def make_image_loader( def make_image_loader(
size="random", size="random",
*, *,
color_space=features.ColorSpace.RGB, color_space=datapoints.ColorSpace.RGB,
extra_dims=(), extra_dims=(),
dtype=torch.float32, dtype=torch.float32,
constant_alpha=True, constant_alpha=True,
...@@ -276,9 +276,9 @@ def make_image_loader( ...@@ -276,9 +276,9 @@ def make_image_loader(
def fn(shape, dtype, device): def fn(shape, dtype, device):
max_value = get_max_value(dtype) max_value = get_max_value(dtype)
data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device) data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
if color_space in {features.ColorSpace.GRAY_ALPHA, features.ColorSpace.RGB_ALPHA} and constant_alpha: if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha:
data[..., -1, :, :] = max_value data[..., -1, :, :] = max_value
return features.Image(data, color_space=color_space) return datapoints.Image(data, color_space=color_space)
return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space) return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space)
...@@ -290,10 +290,10 @@ def make_image_loaders( ...@@ -290,10 +290,10 @@ def make_image_loaders(
*, *,
sizes=DEFAULT_SPATIAL_SIZES, sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=( color_spaces=(
features.ColorSpace.GRAY, datapoints.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.GRAY_ALPHA,
features.ColorSpace.RGB, datapoints.ColorSpace.RGB,
features.ColorSpace.RGB_ALPHA, datapoints.ColorSpace.RGB_ALPHA,
), ),
extra_dims=DEFAULT_EXTRA_DIMS, extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.float32, torch.uint8), dtypes=(torch.float32, torch.uint8),
...@@ -306,7 +306,7 @@ def make_image_loaders( ...@@ -306,7 +306,7 @@ def make_image_loaders(
make_images = from_loaders(make_image_loaders) make_images = from_loaders(make_image_loaders)
def make_image_loader_for_interpolation(size="random", *, color_space=features.ColorSpace.RGB, dtype=torch.uint8): def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8):
size = _parse_spatial_size(size) size = _parse_spatial_size(size)
num_channels = get_num_channels(color_space) num_channels = get_num_channels(color_space)
...@@ -318,24 +318,24 @@ def make_image_loader_for_interpolation(size="random", *, color_space=features.C ...@@ -318,24 +318,24 @@ def make_image_loader_for_interpolation(size="random", *, color_space=features.C
.resize((width, height)) .resize((width, height))
.convert( .convert(
{ {
features.ColorSpace.GRAY: "L", datapoints.ColorSpace.GRAY: "L",
features.ColorSpace.GRAY_ALPHA: "LA", datapoints.ColorSpace.GRAY_ALPHA: "LA",
features.ColorSpace.RGB: "RGB", datapoints.ColorSpace.RGB: "RGB",
features.ColorSpace.RGB_ALPHA: "RGBA", datapoints.ColorSpace.RGB_ALPHA: "RGBA",
}[color_space] }[color_space]
) )
) )
image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype) image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
return features.Image(image_tensor, color_space=color_space) return datapoints.Image(image_tensor, color_space=color_space)
return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space) return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space)
def make_image_loaders_for_interpolation( def make_image_loaders_for_interpolation(
sizes=((233, 147),), sizes=((233, 147),),
color_spaces=(features.ColorSpace.RGB,), color_spaces=(datapoints.ColorSpace.RGB,),
dtypes=(torch.uint8,), dtypes=(torch.uint8,),
): ):
for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes): for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
...@@ -344,7 +344,7 @@ def make_image_loaders_for_interpolation( ...@@ -344,7 +344,7 @@ def make_image_loaders_for_interpolation(
@dataclasses.dataclass @dataclasses.dataclass
class BoundingBoxLoader(TensorLoader): class BoundingBoxLoader(TensorLoader):
format: features.BoundingBoxFormat format: datapoints.BoundingBoxFormat
spatial_size: Tuple[int, int] spatial_size: Tuple[int, int]
...@@ -362,11 +362,11 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): ...@@ -362,11 +362,11 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32): def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32):
if isinstance(format, str): if isinstance(format, str):
format = features.BoundingBoxFormat[format] format = datapoints.BoundingBoxFormat[format]
if format not in { if format not in {
features.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYXY,
features.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.XYWH,
features.BoundingBoxFormat.CXCYWH, datapoints.BoundingBoxFormat.CXCYWH,
}: }:
raise pytest.UsageError(f"Can't make bounding box in format {format}") raise pytest.UsageError(f"Can't make bounding box in format {format}")
...@@ -378,19 +378,19 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dt ...@@ -378,19 +378,19 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dt
raise pytest.UsageError() raise pytest.UsageError()
if any(dim == 0 for dim in extra_dims): if any(dim == 0 for dim in extra_dims):
return features.BoundingBox( return datapoints.BoundingBox(
torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
) )
height, width = spatial_size height, width = spatial_size
if format == features.BoundingBoxFormat.XYXY: if format == datapoints.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, extra_dims) x1 = torch.randint(0, width // 2, extra_dims)
y1 = torch.randint(0, height // 2, extra_dims) y1 = torch.randint(0, height // 2, extra_dims)
x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1 x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1 y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
parts = (x1, y1, x2, y2) parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH: elif format == datapoints.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, extra_dims) x = torch.randint(0, width // 2, extra_dims)
y = torch.randint(0, height // 2, extra_dims) y = torch.randint(0, height // 2, extra_dims)
w = randint_with_tensor_bounds(1, width - x) w = randint_with_tensor_bounds(1, width - x)
...@@ -403,7 +403,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dt ...@@ -403,7 +403,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dt
h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1) h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
parts = (cx, cy, w, h) parts = (cx, cy, w, h)
return features.BoundingBox( return datapoints.BoundingBox(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
) )
...@@ -416,7 +416,7 @@ make_bounding_box = from_loader(make_bounding_box_loader) ...@@ -416,7 +416,7 @@ make_bounding_box = from_loader(make_bounding_box_loader)
def make_bounding_box_loaders( def make_bounding_box_loaders(
*, *,
extra_dims=DEFAULT_EXTRA_DIMS, extra_dims=DEFAULT_EXTRA_DIMS,
formats=tuple(features.BoundingBoxFormat), formats=tuple(datapoints.BoundingBoxFormat),
spatial_size="random", spatial_size="random",
dtypes=(torch.float32, torch.int64), dtypes=(torch.float32, torch.int64),
): ):
...@@ -456,7 +456,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64): ...@@ -456,7 +456,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values, # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123 # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype) data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
return features.Label(data, categories=categories) return datapoints.Label(data, categories=categories)
return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories) return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
...@@ -480,7 +480,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int ...@@ -480,7 +480,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int
# since `one_hot` only supports int64 # since `one_hot` only supports int64
label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device) label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
data = one_hot(label, num_classes=num_categories).to(dtype) data = one_hot(label, num_classes=num_categories).to(dtype)
return features.OneHotLabel(data, categories=categories) return datapoints.OneHotLabel(data, categories=categories)
return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories) return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
...@@ -509,7 +509,7 @@ def make_detection_mask_loader(size="random", *, num_objects="random", extra_dim ...@@ -509,7 +509,7 @@ def make_detection_mask_loader(size="random", *, num_objects="random", extra_dim
def fn(shape, dtype, device): def fn(shape, dtype, device):
data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device) data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device)
return features.Mask(data) return datapoints.Mask(data)
return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype) return MaskLoader(fn, shape=(*extra_dims, num_objects, *size), dtype=dtype)
...@@ -537,7 +537,7 @@ def make_segmentation_mask_loader(size="random", *, num_categories="random", ext ...@@ -537,7 +537,7 @@ def make_segmentation_mask_loader(size="random", *, num_categories="random", ext
def fn(shape, dtype, device): def fn(shape, dtype, device):
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device) data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device)
return features.Mask(data) return datapoints.Mask(data)
return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype) return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
...@@ -583,7 +583,7 @@ class VideoLoader(ImageLoader): ...@@ -583,7 +583,7 @@ class VideoLoader(ImageLoader):
def make_video_loader( def make_video_loader(
size="random", size="random",
*, *,
color_space=features.ColorSpace.RGB, color_space=datapoints.ColorSpace.RGB,
num_frames="random", num_frames="random",
extra_dims=(), extra_dims=(),
dtype=torch.uint8, dtype=torch.uint8,
...@@ -593,7 +593,7 @@ def make_video_loader( ...@@ -593,7 +593,7 @@ def make_video_loader(
def fn(shape, dtype, device): def fn(shape, dtype, device):
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device) video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
return features.Video(video, color_space=color_space) return datapoints.Video(video, color_space=color_space)
return VideoLoader( return VideoLoader(
fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
...@@ -607,8 +607,8 @@ def make_video_loaders( ...@@ -607,8 +607,8 @@ def make_video_loaders(
*, *,
sizes=DEFAULT_SPATIAL_SIZES, sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=( color_spaces=(
features.ColorSpace.GRAY, datapoints.ColorSpace.GRAY,
features.ColorSpace.RGB, datapoints.ColorSpace.RGB,
), ),
num_frames=(1, 0, "random"), num_frames=(1, 0, "random"),
extra_dims=DEFAULT_EXTRA_DIMS, extra_dims=DEFAULT_EXTRA_DIMS,
......
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import torchvision.prototype.transforms.functional as F import torchvision.prototype.transforms.functional as F
from prototype_common_utils import InfoBase, TestMark from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS
from torchvision.prototype import features from torchvision.prototype import datapoints
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
...@@ -139,20 +139,20 @@ DISPATCHER_INFOS = [ ...@@ -139,20 +139,20 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.horizontal_flip, F.horizontal_flip,
kernels={ kernels={
features.Image: F.horizontal_flip_image_tensor, datapoints.Image: F.horizontal_flip_image_tensor,
features.Video: F.horizontal_flip_video, datapoints.Video: F.horizontal_flip_video,
features.BoundingBox: F.horizontal_flip_bounding_box, datapoints.BoundingBox: F.horizontal_flip_bounding_box,
features.Mask: F.horizontal_flip_mask, datapoints.Mask: F.horizontal_flip_mask,
}, },
pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"), pil_kernel_info=PILKernelInfo(F.horizontal_flip_image_pil, kernel_name="horizontal_flip_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.resize, F.resize,
kernels={ kernels={
features.Image: F.resize_image_tensor, datapoints.Image: F.resize_image_tensor,
features.Video: F.resize_video, datapoints.Video: F.resize_video,
features.BoundingBox: F.resize_bounding_box, datapoints.BoundingBox: F.resize_bounding_box,
features.Mask: F.resize_mask, datapoints.Mask: F.resize_mask,
}, },
pil_kernel_info=PILKernelInfo(F.resize_image_pil), pil_kernel_info=PILKernelInfo(F.resize_image_pil),
test_marks=[ test_marks=[
...@@ -162,10 +162,10 @@ DISPATCHER_INFOS = [ ...@@ -162,10 +162,10 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.affine, F.affine,
kernels={ kernels={
features.Image: F.affine_image_tensor, datapoints.Image: F.affine_image_tensor,
features.Video: F.affine_video, datapoints.Video: F.affine_video,
features.BoundingBox: F.affine_bounding_box, datapoints.BoundingBox: F.affine_bounding_box,
features.Mask: F.affine_mask, datapoints.Mask: F.affine_mask,
}, },
pil_kernel_info=PILKernelInfo(F.affine_image_pil), pil_kernel_info=PILKernelInfo(F.affine_image_pil),
test_marks=[ test_marks=[
...@@ -179,20 +179,20 @@ DISPATCHER_INFOS = [ ...@@ -179,20 +179,20 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.vertical_flip, F.vertical_flip,
kernels={ kernels={
features.Image: F.vertical_flip_image_tensor, datapoints.Image: F.vertical_flip_image_tensor,
features.Video: F.vertical_flip_video, datapoints.Video: F.vertical_flip_video,
features.BoundingBox: F.vertical_flip_bounding_box, datapoints.BoundingBox: F.vertical_flip_bounding_box,
features.Mask: F.vertical_flip_mask, datapoints.Mask: F.vertical_flip_mask,
}, },
pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"), pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.rotate, F.rotate,
kernels={ kernels={
features.Image: F.rotate_image_tensor, datapoints.Image: F.rotate_image_tensor,
features.Video: F.rotate_video, datapoints.Video: F.rotate_video,
features.BoundingBox: F.rotate_bounding_box, datapoints.BoundingBox: F.rotate_bounding_box,
features.Mask: F.rotate_mask, datapoints.Mask: F.rotate_mask,
}, },
pil_kernel_info=PILKernelInfo(F.rotate_image_pil), pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
test_marks=[ test_marks=[
...@@ -204,30 +204,30 @@ DISPATCHER_INFOS = [ ...@@ -204,30 +204,30 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.crop, F.crop,
kernels={ kernels={
features.Image: F.crop_image_tensor, datapoints.Image: F.crop_image_tensor,
features.Video: F.crop_video, datapoints.Video: F.crop_video,
features.BoundingBox: F.crop_bounding_box, datapoints.BoundingBox: F.crop_bounding_box,
features.Mask: F.crop_mask, datapoints.Mask: F.crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"), pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.resized_crop, F.resized_crop,
kernels={ kernels={
features.Image: F.resized_crop_image_tensor, datapoints.Image: F.resized_crop_image_tensor,
features.Video: F.resized_crop_video, datapoints.Video: F.resized_crop_video,
features.BoundingBox: F.resized_crop_bounding_box, datapoints.BoundingBox: F.resized_crop_bounding_box,
features.Mask: F.resized_crop_mask, datapoints.Mask: F.resized_crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil), pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
), ),
DispatcherInfo( DispatcherInfo(
F.pad, F.pad,
kernels={ kernels={
features.Image: F.pad_image_tensor, datapoints.Image: F.pad_image_tensor,
features.Video: F.pad_video, datapoints.Video: F.pad_video,
features.BoundingBox: F.pad_bounding_box, datapoints.BoundingBox: F.pad_bounding_box,
features.Mask: F.pad_mask, datapoints.Mask: F.pad_mask,
}, },
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"), pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[ test_marks=[
...@@ -251,10 +251,10 @@ DISPATCHER_INFOS = [ ...@@ -251,10 +251,10 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.perspective, F.perspective,
kernels={ kernels={
features.Image: F.perspective_image_tensor, datapoints.Image: F.perspective_image_tensor,
features.Video: F.perspective_video, datapoints.Video: F.perspective_video,
features.BoundingBox: F.perspective_bounding_box, datapoints.BoundingBox: F.perspective_bounding_box,
features.Mask: F.perspective_mask, datapoints.Mask: F.perspective_mask,
}, },
pil_kernel_info=PILKernelInfo(F.perspective_image_pil), pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
test_marks=[ test_marks=[
...@@ -264,20 +264,20 @@ DISPATCHER_INFOS = [ ...@@ -264,20 +264,20 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.elastic, F.elastic,
kernels={ kernels={
features.Image: F.elastic_image_tensor, datapoints.Image: F.elastic_image_tensor,
features.Video: F.elastic_video, datapoints.Video: F.elastic_video,
features.BoundingBox: F.elastic_bounding_box, datapoints.BoundingBox: F.elastic_bounding_box,
features.Mask: F.elastic_mask, datapoints.Mask: F.elastic_mask,
}, },
pil_kernel_info=PILKernelInfo(F.elastic_image_pil), pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
), ),
DispatcherInfo( DispatcherInfo(
F.center_crop, F.center_crop,
kernels={ kernels={
features.Image: F.center_crop_image_tensor, datapoints.Image: F.center_crop_image_tensor,
features.Video: F.center_crop_video, datapoints.Video: F.center_crop_video,
features.BoundingBox: F.center_crop_bounding_box, datapoints.BoundingBox: F.center_crop_bounding_box,
features.Mask: F.center_crop_mask, datapoints.Mask: F.center_crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil), pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
test_marks=[ test_marks=[
...@@ -287,8 +287,8 @@ DISPATCHER_INFOS = [ ...@@ -287,8 +287,8 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.gaussian_blur, F.gaussian_blur,
kernels={ kernels={
features.Image: F.gaussian_blur_image_tensor, datapoints.Image: F.gaussian_blur_image_tensor,
features.Video: F.gaussian_blur_video, datapoints.Video: F.gaussian_blur_video,
}, },
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil), pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
test_marks=[ test_marks=[
...@@ -299,56 +299,56 @@ DISPATCHER_INFOS = [ ...@@ -299,56 +299,56 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.equalize, F.equalize,
kernels={ kernels={
features.Image: F.equalize_image_tensor, datapoints.Image: F.equalize_image_tensor,
features.Video: F.equalize_video, datapoints.Video: F.equalize_video,
}, },
pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"), pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.invert, F.invert,
kernels={ kernels={
features.Image: F.invert_image_tensor, datapoints.Image: F.invert_image_tensor,
features.Video: F.invert_video, datapoints.Video: F.invert_video,
}, },
pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"), pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.posterize, F.posterize,
kernels={ kernels={
features.Image: F.posterize_image_tensor, datapoints.Image: F.posterize_image_tensor,
features.Video: F.posterize_video, datapoints.Video: F.posterize_video,
}, },
pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"), pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.solarize, F.solarize,
kernels={ kernels={
features.Image: F.solarize_image_tensor, datapoints.Image: F.solarize_image_tensor,
features.Video: F.solarize_video, datapoints.Video: F.solarize_video,
}, },
pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"), pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.autocontrast, F.autocontrast,
kernels={ kernels={
features.Image: F.autocontrast_image_tensor, datapoints.Image: F.autocontrast_image_tensor,
features.Video: F.autocontrast_video, datapoints.Video: F.autocontrast_video,
}, },
pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"), pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_sharpness, F.adjust_sharpness,
kernels={ kernels={
features.Image: F.adjust_sharpness_image_tensor, datapoints.Image: F.adjust_sharpness_image_tensor,
features.Video: F.adjust_sharpness_video, datapoints.Video: F.adjust_sharpness_video,
}, },
pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"), pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.erase, F.erase,
kernels={ kernels={
features.Image: F.erase_image_tensor, datapoints.Image: F.erase_image_tensor,
features.Video: F.erase_video, datapoints.Video: F.erase_video,
}, },
pil_kernel_info=PILKernelInfo(F.erase_image_pil), pil_kernel_info=PILKernelInfo(F.erase_image_pil),
test_marks=[ test_marks=[
...@@ -358,48 +358,48 @@ DISPATCHER_INFOS = [ ...@@ -358,48 +358,48 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.adjust_brightness, F.adjust_brightness,
kernels={ kernels={
features.Image: F.adjust_brightness_image_tensor, datapoints.Image: F.adjust_brightness_image_tensor,
features.Video: F.adjust_brightness_video, datapoints.Video: F.adjust_brightness_video,
}, },
pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"), pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_contrast, F.adjust_contrast,
kernels={ kernels={
features.Image: F.adjust_contrast_image_tensor, datapoints.Image: F.adjust_contrast_image_tensor,
features.Video: F.adjust_contrast_video, datapoints.Video: F.adjust_contrast_video,
}, },
pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"), pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_gamma, F.adjust_gamma,
kernels={ kernels={
features.Image: F.adjust_gamma_image_tensor, datapoints.Image: F.adjust_gamma_image_tensor,
features.Video: F.adjust_gamma_video, datapoints.Video: F.adjust_gamma_video,
}, },
pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"), pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_hue, F.adjust_hue,
kernels={ kernels={
features.Image: F.adjust_hue_image_tensor, datapoints.Image: F.adjust_hue_image_tensor,
features.Video: F.adjust_hue_video, datapoints.Video: F.adjust_hue_video,
}, },
pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"), pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_saturation, F.adjust_saturation,
kernels={ kernels={
features.Image: F.adjust_saturation_image_tensor, datapoints.Image: F.adjust_saturation_image_tensor,
features.Video: F.adjust_saturation_video, datapoints.Video: F.adjust_saturation_video,
}, },
pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"), pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.five_crop, F.five_crop,
kernels={ kernels={
features.Image: F.five_crop_image_tensor, datapoints.Image: F.five_crop_image_tensor,
features.Video: F.five_crop_video, datapoints.Video: F.five_crop_video,
}, },
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil), pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
test_marks=[ test_marks=[
...@@ -410,8 +410,8 @@ DISPATCHER_INFOS = [ ...@@ -410,8 +410,8 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.ten_crop, F.ten_crop,
kernels={ kernels={
features.Image: F.ten_crop_image_tensor, datapoints.Image: F.ten_crop_image_tensor,
features.Video: F.ten_crop_video, datapoints.Video: F.ten_crop_video,
}, },
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("size"), xfail_jit_python_scalar_arg("size"),
...@@ -422,8 +422,8 @@ DISPATCHER_INFOS = [ ...@@ -422,8 +422,8 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.normalize, F.normalize,
kernels={ kernels={
features.Image: F.normalize_image_tensor, datapoints.Image: F.normalize_image_tensor,
features.Video: F.normalize_video, datapoints.Video: F.normalize_video,
}, },
test_marks=[ test_marks=[
skip_dispatch_feature, skip_dispatch_feature,
...@@ -434,8 +434,8 @@ DISPATCHER_INFOS = [ ...@@ -434,8 +434,8 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.convert_dtype, F.convert_dtype,
kernels={ kernels={
features.Image: F.convert_dtype_image_tensor, datapoints.Image: F.convert_dtype_image_tensor,
features.Video: F.convert_dtype_video, datapoints.Video: F.convert_dtype_video,
}, },
test_marks=[ test_marks=[
skip_dispatch_feature, skip_dispatch_feature,
...@@ -444,7 +444,7 @@ DISPATCHER_INFOS = [ ...@@ -444,7 +444,7 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.uniform_temporal_subsample, F.uniform_temporal_subsample,
kernels={ kernels={
features.Video: F.uniform_temporal_subsample_video, datapoints.Video: F.uniform_temporal_subsample_video,
}, },
test_marks=[ test_marks=[
skip_dispatch_feature, skip_dispatch_feature,
......
...@@ -26,7 +26,7 @@ from prototype_common_utils import ( ...@@ -26,7 +26,7 @@ from prototype_common_utils import (
TestMark, TestMark,
) )
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding
__all__ = ["KernelInfo", "KERNEL_INFOS"] __all__ = ["KernelInfo", "KERNEL_INFOS"]
...@@ -176,7 +176,7 @@ def reference_inputs_horizontal_flip_image_tensor(): ...@@ -176,7 +176,7 @@ def reference_inputs_horizontal_flip_image_tensor():
def sample_inputs_horizontal_flip_bounding_box(): def sample_inputs_horizontal_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders( for bounding_box_loader in make_bounding_box_loaders(
formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
): ):
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
...@@ -258,13 +258,13 @@ def _get_resize_sizes(spatial_size): ...@@ -258,13 +258,13 @@ def _get_resize_sizes(spatial_size):
def sample_inputs_resize_image_tensor(): def sample_inputs_resize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
): ):
for size in _get_resize_sizes(image_loader.spatial_size): for size in _get_resize_sizes(image_loader.spatial_size):
yield ArgsKwargs(image_loader, size=size) yield ArgsKwargs(image_loader, size=size)
for image_loader, interpolation in itertools.product( for image_loader, interpolation in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB]), make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB]),
[ [
F.InterpolationMode.NEAREST, F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR, F.InterpolationMode.BILINEAR,
...@@ -468,7 +468,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): ...@@ -468,7 +468,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs):
def sample_inputs_affine_image_tensor(): def sample_inputs_affine_image_tensor():
make_affine_image_loaders = functools.partial( make_affine_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
) )
for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS): for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS):
...@@ -499,7 +499,7 @@ def reference_inputs_affine_image_tensor(): ...@@ -499,7 +499,7 @@ def reference_inputs_affine_image_tensor():
def sample_inputs_affine_bounding_box(): def sample_inputs_affine_bounding_box():
for bounding_box_loader, affine_params in itertools.product( for bounding_box_loader, affine_params in itertools.product(
make_bounding_box_loaders(formats=[features.BoundingBoxFormat.XYXY]), _DIVERSE_AFFINE_PARAMS make_bounding_box_loaders(formats=[datapoints.BoundingBoxFormat.XYXY]), _DIVERSE_AFFINE_PARAMS
): ):
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_box_loader,
...@@ -537,7 +537,7 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix) ...@@ -537,7 +537,7 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
in_dtype = bbox.dtype in_dtype = bbox.dtype
bbox_xyxy = F.convert_format_bounding_box( bbox_xyxy = F.convert_format_bounding_box(
bbox.float(), old_format=format_, new_format=features.BoundingBoxFormat.XYXY, inplace=True bbox.float(), old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
) )
points = np.array( points = np.array(
[ [
...@@ -557,7 +557,7 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix) ...@@ -557,7 +557,7 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
], ],
) )
out_bbox = F.convert_format_bounding_box( out_bbox = F.convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format_, inplace=True out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
) )
return out_bbox.to(dtype=in_dtype) return out_bbox.to(dtype=in_dtype)
...@@ -652,7 +652,7 @@ KERNEL_INFOS.extend( ...@@ -652,7 +652,7 @@ KERNEL_INFOS.extend(
def sample_inputs_convert_format_bounding_box(): def sample_inputs_convert_format_bounding_box():
formats = list(features.BoundingBoxFormat) formats = list(datapoints.BoundingBoxFormat)
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats): for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format) yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
...@@ -681,7 +681,7 @@ KERNEL_INFOS.append( ...@@ -681,7 +681,7 @@ KERNEL_INFOS.append(
def sample_inputs_convert_color_space_image_tensor(): def sample_inputs_convert_color_space_image_tensor():
color_spaces = sorted( color_spaces = sorted(
set(features.ColorSpace) - {features.ColorSpace.OTHER}, key=lambda color_space: color_space.value set(datapoints.ColorSpace) - {datapoints.ColorSpace.OTHER}, key=lambda color_space: color_space.value
) )
for old_color_space, new_color_space in cycle_over(color_spaces): for old_color_space, new_color_space in cycle_over(color_spaces):
...@@ -697,7 +697,7 @@ def sample_inputs_convert_color_space_image_tensor(): ...@@ -697,7 +697,7 @@ def sample_inputs_convert_color_space_image_tensor():
@pil_reference_wrapper @pil_reference_wrapper
def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space): def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
color_space_pil = features.ColorSpace.from_pil_mode(image_pil.mode) color_space_pil = datapoints.ColorSpace.from_pil_mode(image_pil.mode)
if color_space_pil != old_color_space: if color_space_pil != old_color_space:
raise pytest.UsageError( raise pytest.UsageError(
f"Converting the tensor image into an PIL image changed the colorspace " f"Converting the tensor image into an PIL image changed the colorspace "
...@@ -715,7 +715,7 @@ def reference_inputs_convert_color_space_image_tensor(): ...@@ -715,7 +715,7 @@ def reference_inputs_convert_color_space_image_tensor():
def sample_inputs_convert_color_space_video(): def sample_inputs_convert_color_space_video():
color_spaces = [features.ColorSpace.GRAY, features.ColorSpace.RGB] color_spaces = [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB]
for old_color_space, new_color_space in cycle_over(color_spaces): for old_color_space, new_color_space in cycle_over(color_spaces):
for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]): for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]):
...@@ -754,7 +754,7 @@ def reference_inputs_vertical_flip_image_tensor(): ...@@ -754,7 +754,7 @@ def reference_inputs_vertical_flip_image_tensor():
def sample_inputs_vertical_flip_bounding_box(): def sample_inputs_vertical_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders( for bounding_box_loader in make_bounding_box_loaders(
formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
): ):
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
...@@ -817,7 +817,7 @@ _ROTATE_ANGLES = [-87, 15, 90] ...@@ -817,7 +817,7 @@ _ROTATE_ANGLES = [-87, 15, 90]
def sample_inputs_rotate_image_tensor(): def sample_inputs_rotate_image_tensor():
make_rotate_image_loaders = functools.partial( make_rotate_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
) )
for image_loader in make_rotate_image_loaders(): for image_loader in make_rotate_image_loaders():
...@@ -899,7 +899,7 @@ _CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20 ...@@ -899,7 +899,7 @@ _CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20
def sample_inputs_crop_image_tensor(): def sample_inputs_crop_image_tensor():
for image_loader, params in itertools.product( for image_loader, params in itertools.product(
make_image_loaders(sizes=[(16, 17)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]), make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
[ [
dict(top=4, left=3, height=7, width=8), dict(top=4, left=3, height=7, width=8),
dict(top=-1, left=3, height=7, width=8), dict(top=-1, left=3, height=7, width=8),
...@@ -1085,7 +1085,7 @@ _PAD_PARAMS = combinations_grid( ...@@ -1085,7 +1085,7 @@ _PAD_PARAMS = combinations_grid(
def sample_inputs_pad_image_tensor(): def sample_inputs_pad_image_tensor():
make_pad_image_loaders = functools.partial( make_pad_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
) )
for image_loader, padding in itertools.product( for image_loader, padding in itertools.product(
...@@ -1401,7 +1401,7 @@ _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)] ...@@ -1401,7 +1401,7 @@ _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
def sample_inputs_center_crop_image_tensor(): def sample_inputs_center_crop_image_tensor():
for image_loader, output_size in itertools.product( for image_loader, output_size in itertools.product(
make_image_loaders(sizes=[(16, 17)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]), make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
[ [
# valid `output_size` types for which cropping is applied to both dimensions # valid `output_size` types for which cropping is applied to both dimensions
*[5, (4,), (2, 3), [6], [3, 2]], *[5, (4,), (2, 3), [6], [3, 2]],
...@@ -1488,7 +1488,7 @@ KERNEL_INFOS.extend( ...@@ -1488,7 +1488,7 @@ KERNEL_INFOS.extend(
def sample_inputs_gaussian_blur_image_tensor(): def sample_inputs_gaussian_blur_image_tensor():
make_gaussian_blur_image_loaders = functools.partial( make_gaussian_blur_image_loaders = functools.partial(
make_image_loaders, sizes=[(7, 33)], color_spaces=[features.ColorSpace.RGB] make_image_loaders, sizes=[(7, 33)], color_spaces=[datapoints.ColorSpace.RGB]
) )
for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]): for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
...@@ -1527,7 +1527,7 @@ KERNEL_INFOS.extend( ...@@ -1527,7 +1527,7 @@ KERNEL_INFOS.extend(
def sample_inputs_equalize_image_tensor(): def sample_inputs_equalize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
...@@ -1555,7 +1555,7 @@ def reference_inputs_equalize_image_tensor(): ...@@ -1555,7 +1555,7 @@ def reference_inputs_equalize_image_tensor():
spatial_size = (256, 256) spatial_size = (256, 256)
for dtype, color_space, fn in itertools.product( for dtype, color_space, fn in itertools.product(
[torch.uint8], [torch.uint8],
[features.ColorSpace.GRAY, features.ColorSpace.RGB], [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB],
[ [
lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device), lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
lambda shape, dtype, device: torch.full( lambda shape, dtype, device: torch.full(
...@@ -1611,14 +1611,14 @@ KERNEL_INFOS.extend( ...@@ -1611,14 +1611,14 @@ KERNEL_INFOS.extend(
def sample_inputs_invert_image_tensor(): def sample_inputs_invert_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
def reference_inputs_invert_image_tensor(): def reference_inputs_invert_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
): ):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
...@@ -1651,7 +1651,7 @@ _POSTERIZE_BITS = [1, 4, 8] ...@@ -1651,7 +1651,7 @@ _POSTERIZE_BITS = [1, 4, 8]
def sample_inputs_posterize_image_tensor(): def sample_inputs_posterize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])
...@@ -1659,7 +1659,7 @@ def sample_inputs_posterize_image_tensor(): ...@@ -1659,7 +1659,7 @@ def sample_inputs_posterize_image_tensor():
def reference_inputs_posterize_image_tensor(): def reference_inputs_posterize_image_tensor():
for image_loader, bits in itertools.product( for image_loader, bits in itertools.product(
make_image_loaders( make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
), ),
_POSTERIZE_BITS, _POSTERIZE_BITS,
): ):
...@@ -1698,14 +1698,14 @@ def _get_solarize_thresholds(dtype): ...@@ -1698,14 +1698,14 @@ def _get_solarize_thresholds(dtype):
def sample_inputs_solarize_image_tensor(): def sample_inputs_solarize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))
def reference_inputs_solarize_image_tensor(): def reference_inputs_solarize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
): ):
for threshold in _get_solarize_thresholds(image_loader.dtype): for threshold in _get_solarize_thresholds(image_loader.dtype):
yield ArgsKwargs(image_loader, threshold=threshold) yield ArgsKwargs(image_loader, threshold=threshold)
...@@ -1741,14 +1741,14 @@ KERNEL_INFOS.extend( ...@@ -1741,14 +1741,14 @@ KERNEL_INFOS.extend(
def sample_inputs_autocontrast_image_tensor(): def sample_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
def reference_inputs_autocontrast_image_tensor(): def reference_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
): ):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
...@@ -1785,7 +1785,7 @@ _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5] ...@@ -1785,7 +1785,7 @@ _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_sharpness_image_tensor(): def sample_inputs_adjust_sharpness_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random", (2, 2)], sizes=["random", (2, 2)],
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB),
): ):
yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
...@@ -1793,7 +1793,7 @@ def sample_inputs_adjust_sharpness_image_tensor(): ...@@ -1793,7 +1793,7 @@ def sample_inputs_adjust_sharpness_image_tensor():
def reference_inputs_adjust_sharpness_image_tensor(): def reference_inputs_adjust_sharpness_image_tensor():
for image_loader, sharpness_factor in itertools.product( for image_loader, sharpness_factor in itertools.product(
make_image_loaders( make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
), ),
_ADJUST_SHARPNESS_FACTORS, _ADJUST_SHARPNESS_FACTORS,
): ):
...@@ -1859,7 +1859,7 @@ _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5] ...@@ -1859,7 +1859,7 @@ _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_brightness_image_tensor(): def sample_inputs_adjust_brightness_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
...@@ -1867,7 +1867,7 @@ def sample_inputs_adjust_brightness_image_tensor(): ...@@ -1867,7 +1867,7 @@ def sample_inputs_adjust_brightness_image_tensor():
def reference_inputs_adjust_brightness_image_tensor(): def reference_inputs_adjust_brightness_image_tensor():
for image_loader, brightness_factor in itertools.product( for image_loader, brightness_factor in itertools.product(
make_image_loaders( make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
), ),
_ADJUST_BRIGHTNESS_FACTORS, _ADJUST_BRIGHTNESS_FACTORS,
): ):
...@@ -1903,7 +1903,7 @@ _ADJUST_CONTRAST_FACTORS = [0.1, 0.5] ...@@ -1903,7 +1903,7 @@ _ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_contrast_image_tensor(): def sample_inputs_adjust_contrast_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])
...@@ -1911,7 +1911,7 @@ def sample_inputs_adjust_contrast_image_tensor(): ...@@ -1911,7 +1911,7 @@ def sample_inputs_adjust_contrast_image_tensor():
def reference_inputs_adjust_contrast_image_tensor(): def reference_inputs_adjust_contrast_image_tensor():
for image_loader, contrast_factor in itertools.product( for image_loader, contrast_factor in itertools.product(
make_image_loaders( make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
), ),
_ADJUST_CONTRAST_FACTORS, _ADJUST_CONTRAST_FACTORS,
): ):
...@@ -1953,7 +1953,7 @@ _ADJUST_GAMMA_GAMMAS_GAINS = [ ...@@ -1953,7 +1953,7 @@ _ADJUST_GAMMA_GAMMAS_GAINS = [
def sample_inputs_adjust_gamma_image_tensor(): def sample_inputs_adjust_gamma_image_tensor():
gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
...@@ -1961,7 +1961,7 @@ def sample_inputs_adjust_gamma_image_tensor(): ...@@ -1961,7 +1961,7 @@ def sample_inputs_adjust_gamma_image_tensor():
def reference_inputs_adjust_gamma_image_tensor(): def reference_inputs_adjust_gamma_image_tensor():
for image_loader, (gamma, gain) in itertools.product( for image_loader, (gamma, gain) in itertools.product(
make_image_loaders( make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
), ),
_ADJUST_GAMMA_GAMMAS_GAINS, _ADJUST_GAMMA_GAMMAS_GAINS,
): ):
...@@ -2001,7 +2001,7 @@ _ADJUST_HUE_FACTORS = [-0.1, 0.5] ...@@ -2001,7 +2001,7 @@ _ADJUST_HUE_FACTORS = [-0.1, 0.5]
def sample_inputs_adjust_hue_image_tensor(): def sample_inputs_adjust_hue_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])
...@@ -2009,7 +2009,7 @@ def sample_inputs_adjust_hue_image_tensor(): ...@@ -2009,7 +2009,7 @@ def sample_inputs_adjust_hue_image_tensor():
def reference_inputs_adjust_hue_image_tensor(): def reference_inputs_adjust_hue_image_tensor():
for image_loader, hue_factor in itertools.product( for image_loader, hue_factor in itertools.product(
make_image_loaders( make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
), ),
_ADJUST_HUE_FACTORS, _ADJUST_HUE_FACTORS,
): ):
...@@ -2047,7 +2047,7 @@ _ADJUST_SATURATION_FACTORS = [0.1, 0.5] ...@@ -2047,7 +2047,7 @@ _ADJUST_SATURATION_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_saturation_image_tensor(): def sample_inputs_adjust_saturation_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
): ):
yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])
...@@ -2055,7 +2055,7 @@ def sample_inputs_adjust_saturation_image_tensor(): ...@@ -2055,7 +2055,7 @@ def sample_inputs_adjust_saturation_image_tensor():
def reference_inputs_adjust_saturation_image_tensor(): def reference_inputs_adjust_saturation_image_tensor():
for image_loader, saturation_factor in itertools.product( for image_loader, saturation_factor in itertools.product(
make_image_loaders( make_image_loaders(
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
), ),
_ADJUST_SATURATION_FACTORS, _ADJUST_SATURATION_FACTORS,
): ):
...@@ -2120,7 +2120,7 @@ def sample_inputs_five_crop_image_tensor(): ...@@ -2120,7 +2120,7 @@ def sample_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES: for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], sizes=[_get_five_ten_crop_spatial_size(size)],
color_spaces=[features.ColorSpace.RGB], color_spaces=[datapoints.ColorSpace.RGB],
dtypes=[torch.float32], dtypes=[torch.float32],
): ):
yield ArgsKwargs(image_loader, size=size) yield ArgsKwargs(image_loader, size=size)
...@@ -2144,7 +2144,7 @@ def sample_inputs_ten_crop_image_tensor(): ...@@ -2144,7 +2144,7 @@ def sample_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], sizes=[_get_five_ten_crop_spatial_size(size)],
color_spaces=[features.ColorSpace.RGB], color_spaces=[datapoints.ColorSpace.RGB],
dtypes=[torch.float32], dtypes=[torch.float32],
): ):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
...@@ -2218,7 +2218,7 @@ _NORMALIZE_MEANS_STDS = [ ...@@ -2218,7 +2218,7 @@ _NORMALIZE_MEANS_STDS = [
def sample_inputs_normalize_image_tensor(): def sample_inputs_normalize_image_tensor():
for image_loader, (mean, std) in itertools.product( for image_loader, (mean, std) in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]), make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
_NORMALIZE_MEANS_STDS, _NORMALIZE_MEANS_STDS,
): ):
yield ArgsKwargs(image_loader, mean=mean, std=std) yield ArgsKwargs(image_loader, mean=mean, std=std)
...@@ -2227,7 +2227,7 @@ def sample_inputs_normalize_image_tensor(): ...@@ -2227,7 +2227,7 @@ def sample_inputs_normalize_image_tensor():
def sample_inputs_normalize_video(): def sample_inputs_normalize_video():
mean, std = _NORMALIZE_MEANS_STDS[0] mean, std = _NORMALIZE_MEANS_STDS[0]
for video_loader in make_video_loaders( for video_loader in make_video_loaders(
sizes=["random"], color_spaces=[features.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32]
): ):
yield ArgsKwargs(video_loader, mean=mean, std=std) yield ArgsKwargs(video_loader, mean=mean, std=std)
...@@ -2260,7 +2260,7 @@ def sample_inputs_convert_dtype_image_tensor(): ...@@ -2260,7 +2260,7 @@ def sample_inputs_convert_dtype_image_tensor():
continue continue
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[input_dtype] sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[input_dtype]
): ):
yield ArgsKwargs(image_loader, dtype=output_dtype) yield ArgsKwargs(image_loader, dtype=output_dtype)
...@@ -2388,7 +2388,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4): ...@@ -2388,7 +2388,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
def reference_inputs_uniform_temporal_subsample_video(): def reference_inputs_uniform_temporal_subsample_video():
for video_loader in make_video_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], num_frames=[10]): for video_loader in make_video_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=[10]):
for num_samples in range(1, video_loader.shape[-4] + 1): for num_samples in range(1, video_loader.shape[-4] + 1):
yield ArgsKwargs(video_loader, num_samples) yield ArgsKwargs(video_loader, num_samples)
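The sample- and reference-input helpers above all follow one pattern: restrict the image/video loaders to the color spaces and dtypes the kernel supports, then pair every loader with every parameter combination via itertools.product and yield an ArgsKwargs. A self-contained sketch of that pattern, using a stand-in loader and plain dicts instead of the real test utilities (none of the names below are torchvision APIs):

# Illustrative only: "ImageLoaderStub" and the parameter grid are stand-ins,
# not the prototype test utilities used in this file.
import itertools
from dataclasses import dataclass

@dataclass
class ImageLoaderStub:
    color_space: str
    dtype: str

_GAMMAS_GAINS = [(0.5, 1.0), (2.0, 0.9)]  # hypothetical parameter grid

def sample_inputs_sketch():
    loaders = [ImageLoaderStub(cs, "uint8") for cs in ("GRAY", "RGB")]
    # every supported loader is paired with every parameter combination
    for loader, (gamma, gain) in itertools.product(loaders, _GAMMAS_GAINS):
        yield loader, dict(gamma=gamma, gain=gain)

for loader, kwargs in sample_inputs_sketch():
    print(loader, kwargs)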
......
import pytest import pytest
import torch import torch
from torchvision.prototype import features from torchvision.prototype import datapoints
def test_isinstance(): def test_isinstance():
assert isinstance( assert isinstance(
features.Label([0, 1, 0], categories=["foo", "bar"]), datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
torch.Tensor, torch.Tensor,
) )
def test_wrapping_no_copy(): def test_wrapping_no_copy():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = datapoints.Label(tensor, categories=["foo", "bar"])
assert label.data_ptr() == tensor.data_ptr() assert label.data_ptr() == tensor.data_ptr()
def test_to_wrapping(): def test_to_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = datapoints.Label(tensor, categories=["foo", "bar"])
label_to = label.to(torch.int32) label_to = label.to(torch.int32)
assert type(label_to) is features.Label assert type(label_to) is datapoints.Label
assert label_to.dtype is torch.int32 assert label_to.dtype is torch.int32
assert label_to.categories is label.categories assert label_to.categories is label.categories
def test_to_feature_reference(): def test_to_feature_reference():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32) label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
tensor_to = tensor.to(label) tensor_to = tensor.to(label)
...@@ -40,31 +40,31 @@ def test_to_feature_reference(): ...@@ -40,31 +40,31 @@ def test_to_feature_reference():
def test_clone_wrapping(): def test_clone_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = datapoints.Label(tensor, categories=["foo", "bar"])
label_clone = label.clone() label_clone = label.clone()
assert type(label_clone) is features.Label assert type(label_clone) is datapoints.Label
assert label_clone.data_ptr() != label.data_ptr() assert label_clone.data_ptr() != label.data_ptr()
assert label_clone.categories is label.categories assert label_clone.categories is label.categories
def test_requires_grad__wrapping(): def test_requires_grad__wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.float32) tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
label = features.Label(tensor, categories=["foo", "bar"]) label = datapoints.Label(tensor, categories=["foo", "bar"])
assert not label.requires_grad assert not label.requires_grad
label_requires_grad = label.requires_grad_(True) label_requires_grad = label.requires_grad_(True)
assert type(label_requires_grad) is features.Label assert type(label_requires_grad) is datapoints.Label
assert label.requires_grad assert label.requires_grad
assert label_requires_grad.requires_grad assert label_requires_grad.requires_grad
def test_other_op_no_wrapping(): def test_other_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = datapoints.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here # any operation besides .to() and .clone() will do here
output = label * 2 output = label * 2
...@@ -82,32 +82,32 @@ def test_other_op_no_wrapping(): ...@@ -82,32 +82,32 @@ def test_other_op_no_wrapping():
) )
def test_no_tensor_output_op_no_wrapping(op): def test_no_tensor_output_op_no_wrapping(op):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = datapoints.Label(tensor, categories=["foo", "bar"])
output = op(label) output = op(label)
assert type(output) is not features.Label assert type(output) is not datapoints.Label
def test_inplace_op_no_wrapping(): def test_inplace_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = datapoints.Label(tensor, categories=["foo", "bar"])
output = label.add_(0) output = label.add_(0)
assert type(output) is torch.Tensor assert type(output) is torch.Tensor
assert type(label) is features.Label assert type(label) is datapoints.Label
def test_wrap_like(): def test_wrap_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = datapoints.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here # any operation besides .to() and .clone() will do here
output = label * 2 output = label * 2
label_new = features.Label.wrap_like(label, output) label_new = datapoints.Label.wrap_like(label, output)
assert type(label_new) is features.Label assert type(label_new) is datapoints.Label
assert label_new.data_ptr() == output.data_ptr() assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories assert label_new.categories is label.categories
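Taken together, the Label tests above pin down the wrapping rules for Datapoint subclasses: .to() and .clone() keep the subclass and its metadata, other ops (including in-place ones) fall back to a plain torch.Tensor, and wrap_like re-attaches the metadata afterwards. A short usage sketch, assuming the prototype datapoints.Label API exactly as exercised in these tests:

# Usage sketch of the wrapping rules asserted above (prototype API as shown here).
import torch
from torchvision.prototype import datapoints

label = datapoints.Label(torch.tensor([0, 1, 0]), categories=["foo", "bar"])

as_int32 = label.to(torch.int32)    # still a Label, categories preserved
copy = label.clone()                # still a Label, new storage
doubled = label * 2                 # plain torch.Tensor, metadata dropped
rewrapped = datapoints.Label.wrap_like(label, doubled)  # a Label again

assert type(as_int32) is datapoints.Label and as_int32.categories is label.categories
assert type(copy) is datapoints.Label and copy.data_ptr() != label.data_ptr()
assert type(doubled) is torch.Tensor
assert type(rewrapped) is datapoints.Label and rewrapped.categories is label.categories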
...@@ -6,6 +6,8 @@ from pathlib import Path ...@@ -6,6 +6,8 @@ from pathlib import Path
import pytest import pytest
import torch import torch
import torchvision.prototype.transforms.utils
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -14,7 +16,7 @@ from torch.utils.data.graph_settings import get_all_graph_pipes ...@@ -14,7 +16,7 @@ from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import ShardingFilter, Shuffler from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper from torchdata.datapipes.utils import StreamWrapper
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import datasets, features, transforms from torchvision.prototype import datapoints, datasets, transforms
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
...@@ -130,7 +132,11 @@ class TestCommon: ...@@ -130,7 +132,11 @@ class TestCommon:
def test_no_simple_tensors(self, dataset_mock, config): def test_no_simple_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) dataset, _ = dataset_mock.load(config)
simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)} simple_tensors = {
key
for key, value in next_consume(iter(dataset)).items()
if torchvision.prototype.transforms.utils.is_simple_tensor(value)
}
if simple_tensors: if simple_tensors:
raise AssertionError( raise AssertionError(
f"The values of key(s) " f"The values of key(s) "
...@@ -258,7 +264,7 @@ class TestUSPS: ...@@ -258,7 +264,7 @@ class TestUSPS:
assert "image" in sample assert "image" in sample
assert "label" in sample assert "label" in sample
assert isinstance(sample["image"], features.Image) assert isinstance(sample["image"], datapoints.Image)
assert isinstance(sample["label"], features.Label) assert isinstance(sample["label"], datapoints.Label)
assert sample["image"].shape == (1, 16, 16) assert sample["image"].shape == (1, 16, 16)
...@@ -6,6 +6,8 @@ import PIL.Image ...@@ -6,6 +6,8 @@ import PIL.Image
import pytest import pytest
import torch import torch
import torchvision.prototype.transforms.utils
from common_utils import assert_equal, cpu_and_gpu from common_utils import assert_equal, cpu_and_gpu
from prototype_common_utils import ( from prototype_common_utils import (
DEFAULT_EXTRA_DIMS, DEFAULT_EXTRA_DIMS,
...@@ -22,7 +24,7 @@ from prototype_common_utils import ( ...@@ -22,7 +24,7 @@ from prototype_common_utils import (
make_videos, make_videos,
) )
from torchvision.ops.boxes import box_iou from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms from torchvision.prototype import datapoints, transforms
from torchvision.prototype.transforms.utils import check_type from torchvision.prototype.transforms.utils import check_type
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
...@@ -159,8 +161,8 @@ class TestSmoke: ...@@ -159,8 +161,8 @@ class TestSmoke:
itertools.chain.from_iterable( itertools.chain.from_iterable(
fn( fn(
color_spaces=[ color_spaces=[
features.ColorSpace.GRAY, datapoints.ColorSpace.GRAY,
features.ColorSpace.RGB, datapoints.ColorSpace.RGB,
], ],
dtypes=[torch.uint8], dtypes=[torch.uint8],
extra_dims=[(), (4,)], extra_dims=[(), (4,)],
...@@ -190,7 +192,7 @@ class TestSmoke: ...@@ -190,7 +192,7 @@ class TestSmoke:
( (
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
itertools.chain.from_iterable( itertools.chain.from_iterable(
fn(color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]) fn(color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32])
for fn in [ for fn in [
make_images, make_images,
make_vanilla_tensor_images, make_vanilla_tensor_images,
...@@ -237,10 +239,10 @@ class TestSmoke: ...@@ -237,10 +239,10 @@ class TestSmoke:
) )
for old_color_space, new_color_space in itertools.product( for old_color_space, new_color_space in itertools.product(
[ [
features.ColorSpace.GRAY, datapoints.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.GRAY_ALPHA,
features.ColorSpace.RGB, datapoints.ColorSpace.RGB,
features.ColorSpace.RGB_ALPHA, datapoints.ColorSpace.RGB_ALPHA,
], ],
repeat=2, repeat=2,
) )
...@@ -251,7 +253,7 @@ class TestSmoke: ...@@ -251,7 +253,7 @@ class TestSmoke:
def test_convert_color_space_unsupported_types(self): def test_convert_color_space_unsupported_types(self):
transform = transforms.ConvertColorSpace( transform = transforms.ConvertColorSpace(
color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY color_space=datapoints.ColorSpace.RGB, old_color_space=datapoints.ColorSpace.GRAY
) )
for inpt in [make_bounding_box(format="XYXY"), make_masks()]: for inpt in [make_bounding_box(format="XYXY"), make_masks()]:
...@@ -287,26 +289,26 @@ class TestRandomHorizontalFlip: ...@@ -287,26 +289,26 @@ class TestRandomHorizontalFlip:
input, expected = self.input_expected_image_tensor(p) input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p) transform = transforms.RandomHorizontalFlip(p=p)
actual = transform(features.Image(input)) actual = transform(datapoints.Image(input))
assert_equal(features.Image(expected), actual) assert_equal(datapoints.Image(expected), actual)
def test_features_mask(self, p): def test_features_mask(self, p):
input, expected = self.input_expected_image_tensor(p) input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p) transform = transforms.RandomHorizontalFlip(p=p)
actual = transform(features.Mask(input)) actual = transform(datapoints.Mask(input))
assert_equal(features.Mask(expected), actual) assert_equal(datapoints.Mask(expected), actual)
def test_features_bounding_box(self, p): def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, spatial_size=(10, 10)) input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
transform = transforms.RandomHorizontalFlip(p=p) transform = transforms.RandomHorizontalFlip(p=p)
actual = transform(input) actual = transform(input)
expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
expected = features.BoundingBox.wrap_like(input, expected_image_tensor) expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor)
assert_equal(expected, actual) assert_equal(expected, actual)
assert actual.format == expected.format assert actual.format == expected.format
assert actual.spatial_size == expected.spatial_size assert actual.spatial_size == expected.spatial_size
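The expected box in the horizontal-flip test above follows directly from the XYXY convention: flipping across an image of width W maps an x-coordinate to W - x, so [x1, y1, x2, y2] becomes [W - x2, y1, W - x1, y2]; for the test's box [0, 0, 5, 5] on a (10, 10) image that is exactly [5, 0, 10, 5]. A small worked check (hflip_xyxy is just a local helper, not a torchvision function):

def hflip_xyxy(box, image_width):
    x1, y1, x2, y2 = box
    return [image_width - x2, y1, image_width - x1, y2]

assert hflip_xyxy([0, 0, 5, 5], image_width=10) == [5, 0, 10, 5]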
...@@ -340,26 +342,26 @@ class TestRandomVerticalFlip: ...@@ -340,26 +342,26 @@ class TestRandomVerticalFlip:
input, expected = self.input_expected_image_tensor(p) input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p) transform = transforms.RandomVerticalFlip(p=p)
actual = transform(features.Image(input)) actual = transform(datapoints.Image(input))
assert_equal(features.Image(expected), actual) assert_equal(datapoints.Image(expected), actual)
def test_features_mask(self, p): def test_features_mask(self, p):
input, expected = self.input_expected_image_tensor(p) input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p) transform = transforms.RandomVerticalFlip(p=p)
actual = transform(features.Mask(input)) actual = transform(datapoints.Mask(input))
assert_equal(features.Mask(expected), actual) assert_equal(datapoints.Mask(expected), actual)
def test_features_bounding_box(self, p): def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, spatial_size=(10, 10)) input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
transform = transforms.RandomVerticalFlip(p=p) transform = transforms.RandomVerticalFlip(p=p)
actual = transform(input) actual = transform(input)
expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
expected = features.BoundingBox.wrap_like(input, expected_image_tensor) expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor)
assert_equal(expected, actual) assert_equal(expected, actual)
assert actual.format == expected.format assert actual.format == expected.format
assert actual.spatial_size == expected.spatial_size assert actual.spatial_size == expected.spatial_size
...@@ -386,7 +388,7 @@ class TestPad: ...@@ -386,7 +388,7 @@ class TestPad:
transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode) transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
fn = mocker.patch("torchvision.prototype.transforms.functional.pad") fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
_ = transform(inpt) _ = transform(inpt)
fill = transforms._utils._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
...@@ -394,13 +396,13 @@ class TestPad: ...@@ -394,13 +396,13 @@ class TestPad:
padding = list(padding) padding = list(padding)
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
def test__transform_image_mask(self, fill, mocker): def test__transform_image_mask(self, fill, mocker):
transform = transforms.Pad(1, fill=fill, padding_mode="constant") transform = transforms.Pad(1, fill=fill, padding_mode="constant")
fn = mocker.patch("torchvision.prototype.transforms.functional.pad") fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
image = features.Image(torch.rand(3, 32, 32)) image = datapoints.Image(torch.rand(3, 32, 32))
mask = features.Mask(torch.randint(0, 5, size=(32, 32))) mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask] inpt = [image, mask]
_ = transform(inpt) _ = transform(inpt)
...@@ -436,7 +438,7 @@ class TestRandomZoomOut: ...@@ -436,7 +438,7 @@ class TestRandomZoomOut:
def test__get_params(self, fill, side_range, mocker): def test__get_params(self, fill, side_range, mocker):
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=datapoints.Image)
h, w = image.spatial_size = (24, 32) h, w = image.spatial_size = (24, 32)
params = transform._get_params([image]) params = transform._get_params([image])
...@@ -450,7 +452,7 @@ class TestRandomZoomOut: ...@@ -450,7 +452,7 @@ class TestRandomZoomOut:
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
@pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
def test__transform(self, fill, side_range, mocker): def test__transform(self, fill, side_range, mocker):
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.spatial_size = (24, 32) inpt.spatial_size = (24, 32)
...@@ -469,13 +471,13 @@ class TestRandomZoomOut: ...@@ -469,13 +471,13 @@ class TestRandomZoomOut:
fill = transforms._utils._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill) fn.assert_called_once_with(inpt, **params, fill=fill)
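Most of the transform tests in this file use the same mocker recipe: patch the functional kernel, pass a MagicMock that only carries the metadata the transform reads (num_channels, spatial_size), and assert the kernel was called exactly once with the expected arguments. A self-contained sketch of that recipe with a toy transform and kernel (both hypothetical stand-ins, not torchvision APIs):

# "blur_kernel" and "ToyBlur" exist only to show the pattern.
from unittest import mock

def blur_kernel(inpt, sigma):
    raise NotImplementedError  # never called for real: the test patches it out

class ToyBlur:
    def __init__(self, sigma):
        self.sigma = sigma

    def __call__(self, inpt):
        return blur_kernel(inpt, sigma=self.sigma)

with mock.patch(f"{__name__}.blur_kernel") as fn:
    inpt = mock.MagicMock()
    inpt.spatial_size = (24, 32)  # only the metadata a real transform would read
    ToyBlur(sigma=3.0)(inpt)
    fn.assert_called_once_with(inpt, sigma=3.0)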
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
def test__transform_image_mask(self, fill, mocker): def test__transform_image_mask(self, fill, mocker):
transform = transforms.RandomZoomOut(fill=fill, p=1.0) transform = transforms.RandomZoomOut(fill=fill, p=1.0)
fn = mocker.patch("torchvision.prototype.transforms.functional.pad") fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
image = features.Image(torch.rand(3, 32, 32)) image = datapoints.Image(torch.rand(3, 32, 32))
mask = features.Mask(torch.randint(0, 5, size=(32, 32))) mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask] inpt = [image, mask]
torch.manual_seed(12) torch.manual_seed(12)
...@@ -547,7 +549,7 @@ class TestRandomRotation: ...@@ -547,7 +549,7 @@ class TestRandomRotation:
assert transform.degrees == [float(-degrees), float(degrees)] assert transform.degrees == [float(-degrees), float(degrees)]
fn = mocker.patch("torchvision.prototype.transforms.functional.rotate") fn = mocker.patch("torchvision.prototype.transforms.functional.rotate")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
# vfdev-5, Feature Request: let's store params as Transform attribute # vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users # This could be also helpful for users
# Otherwise, we can mock transform._get_params # Otherwise, we can mock transform._get_params
...@@ -563,10 +565,10 @@ class TestRandomRotation: ...@@ -563,10 +565,10 @@ class TestRandomRotation:
@pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("expand", [False, True])
def test_boundingbox_spatial_size(self, angle, expand): def test_boundingbox_spatial_size(self, angle, expand):
# Specific test for BoundingBox.rotate # Specific test for BoundingBox.rotate
bbox = features.BoundingBox( bbox = datapoints.BoundingBox(
torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, spatial_size=(32, 32) torch.tensor([1, 2, 3, 4]), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(32, 32)
) )
img = features.Image(torch.rand(1, 3, 32, 32)) img = datapoints.Image(torch.rand(1, 3, 32, 32))
out_img = img.rotate(angle, expand=expand) out_img = img.rotate(angle, expand=expand)
out_bbox = bbox.rotate(angle, expand=expand) out_bbox = bbox.rotate(angle, expand=expand)
...@@ -619,7 +621,7 @@ class TestRandomAffine: ...@@ -619,7 +621,7 @@ class TestRandomAffine:
@pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) @pytest.mark.parametrize("scale", [None, [0.7, 1.2]])
@pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]])
def test__get_params(self, degrees, translate, scale, shear, mocker): def test__get_params(self, degrees, translate, scale, shear, mocker):
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=datapoints.Image)
image.num_channels = 3 image.num_channels = 3
image.spatial_size = (24, 32) image.spatial_size = (24, 32)
h, w = image.spatial_size h, w = image.spatial_size
...@@ -682,7 +684,7 @@ class TestRandomAffine: ...@@ -682,7 +684,7 @@ class TestRandomAffine:
assert transform.degrees == [float(-degrees), float(degrees)] assert transform.degrees == [float(-degrees), float(degrees)]
fn = mocker.patch("torchvision.prototype.transforms.functional.affine") fn = mocker.patch("torchvision.prototype.transforms.functional.affine")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.spatial_size = (24, 32) inpt.spatial_size = (24, 32)
...@@ -718,7 +720,7 @@ class TestRandomCrop: ...@@ -718,7 +720,7 @@ class TestRandomCrop:
@pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
@pytest.mark.parametrize("size, pad_if_needed", [((10, 10), False), ((50, 25), True)]) @pytest.mark.parametrize("size, pad_if_needed", [((10, 10), False), ((50, 25), True)])
def test__get_params(self, padding, pad_if_needed, size, mocker): def test__get_params(self, padding, pad_if_needed, size, mocker):
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=datapoints.Image)
image.num_channels = 3 image.num_channels = 3
image.spatial_size = (24, 32) image.spatial_size = (24, 32)
h, w = image.spatial_size h, w = image.spatial_size
...@@ -771,11 +773,11 @@ class TestRandomCrop: ...@@ -771,11 +773,11 @@ class TestRandomCrop:
output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
) )
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.spatial_size = (32, 32) inpt.spatial_size = (32, 32)
expected = mocker.MagicMock(spec=features.Image) expected = mocker.MagicMock(spec=datapoints.Image)
expected.num_channels = 3 expected.num_channels = 3
if isinstance(padding, int): if isinstance(padding, int):
expected.spatial_size = (inpt.spatial_size[0] + padding, inpt.spatial_size[1] + padding) expected.spatial_size = (inpt.spatial_size[0] + padding, inpt.spatial_size[1] + padding)
...@@ -859,7 +861,7 @@ class TestGaussianBlur: ...@@ -859,7 +861,7 @@ class TestGaussianBlur:
assert transform.sigma == [sigma, sigma] assert transform.sigma == [sigma, sigma]
fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.spatial_size = (24, 32) inpt.spatial_size = (24, 32)
...@@ -891,7 +893,7 @@ class TestRandomColorOp: ...@@ -891,7 +893,7 @@ class TestRandomColorOp:
transform = transform_cls(p=p, **kwargs) transform = transform_cls(p=p, **kwargs)
fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}") fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
_ = transform(inpt) _ = transform(inpt)
if p > 0.0: if p > 0.0:
fn.assert_called_once_with(inpt, **kwargs) fn.assert_called_once_with(inpt, **kwargs)
...@@ -910,7 +912,7 @@ class TestRandomPerspective: ...@@ -910,7 +912,7 @@ class TestRandomPerspective:
def test__get_params(self, mocker): def test__get_params(self, mocker):
dscale = 0.5 dscale = 0.5
transform = transforms.RandomPerspective(dscale) transform = transforms.RandomPerspective(dscale)
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=datapoints.Image)
image.num_channels = 3 image.num_channels = 3
image.spatial_size = (24, 32) image.spatial_size = (24, 32)
...@@ -927,7 +929,7 @@ class TestRandomPerspective: ...@@ -927,7 +929,7 @@ class TestRandomPerspective:
transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation) transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
fn = mocker.patch("torchvision.prototype.transforms.functional.perspective") fn = mocker.patch("torchvision.prototype.transforms.functional.perspective")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.spatial_size = (24, 32) inpt.spatial_size = (24, 32)
# vfdev-5, Feature Request: let's store params as Transform attribute # vfdev-5, Feature Request: let's store params as Transform attribute
...@@ -971,7 +973,7 @@ class TestElasticTransform: ...@@ -971,7 +973,7 @@ class TestElasticTransform:
alpha = 2.0 alpha = 2.0
sigma = 3.0 sigma = 3.0
transform = transforms.ElasticTransform(alpha, sigma) transform = transforms.ElasticTransform(alpha, sigma)
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=datapoints.Image)
image.num_channels = 3 image.num_channels = 3
image.spatial_size = (24, 32) image.spatial_size = (24, 32)
...@@ -1001,7 +1003,7 @@ class TestElasticTransform: ...@@ -1001,7 +1003,7 @@ class TestElasticTransform:
assert transform.sigma == sigma assert transform.sigma == sigma
fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") fn = mocker.patch("torchvision.prototype.transforms.functional.elastic")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.spatial_size = (24, 32) inpt.spatial_size = (24, 32)
...@@ -1030,7 +1032,7 @@ class TestRandomErasing: ...@@ -1030,7 +1032,7 @@ class TestRandomErasing:
with pytest.raises(ValueError, match="Scale should be between 0 and 1"): with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
transforms.RandomErasing(scale=[-1, 2]) transforms.RandomErasing(scale=[-1, 2])
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=datapoints.Image)
image.num_channels = 3 image.num_channels = 3
image.spatial_size = (24, 32) image.spatial_size = (24, 32)
...@@ -1041,7 +1043,7 @@ class TestRandomErasing: ...@@ -1041,7 +1043,7 @@ class TestRandomErasing:
@pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"]) @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
def test__get_params(self, value, mocker): def test__get_params(self, value, mocker):
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=datapoints.Image)
image.num_channels = 3 image.num_channels = 3
image.spatial_size = (24, 32) image.spatial_size = (24, 32)
...@@ -1100,7 +1102,7 @@ class TestRandomErasing: ...@@ -1100,7 +1102,7 @@ class TestRandomErasing:
class TestTransform: class TestTransform:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt_type", "inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
) )
def test_check_transformed_types(self, inpt_type, mocker): def test_check_transformed_types(self, inpt_type, mocker):
# This test ensures that we correctly handle which types to transform and which to bypass # This test ensures that we correctly handle which types to transform and which to bypass
...@@ -1118,7 +1120,7 @@ class TestTransform: ...@@ -1118,7 +1120,7 @@ class TestTransform:
class TestToImageTensor: class TestToImageTensor:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt_type", "inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
) )
def test__transform(self, inpt_type, mocker): def test__transform(self, inpt_type, mocker):
fn = mocker.patch( fn = mocker.patch(
...@@ -1129,7 +1131,7 @@ class TestToImageTensor: ...@@ -1129,7 +1131,7 @@ class TestToImageTensor:
inpt = mocker.MagicMock(spec=inpt_type) inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImageTensor() transform = transforms.ToImageTensor()
transform(inpt) transform(inpt)
if inpt_type in (features.BoundingBox, features.Image, str, int): if inpt_type in (datapoints.BoundingBox, datapoints.Image, str, int):
assert fn.call_count == 0 assert fn.call_count == 0
else: else:
fn.assert_called_once_with(inpt) fn.assert_called_once_with(inpt)
...@@ -1138,7 +1140,7 @@ class TestToImageTensor: ...@@ -1138,7 +1140,7 @@ class TestToImageTensor:
class TestToImagePIL: class TestToImagePIL:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt_type", "inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
) )
def test__transform(self, inpt_type, mocker): def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil") fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
...@@ -1146,7 +1148,7 @@ class TestToImagePIL: ...@@ -1146,7 +1148,7 @@ class TestToImagePIL:
inpt = mocker.MagicMock(spec=inpt_type) inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImagePIL() transform = transforms.ToImagePIL()
transform(inpt) transform(inpt)
if inpt_type in (features.BoundingBox, PIL.Image.Image, str, int): if inpt_type in (datapoints.BoundingBox, PIL.Image.Image, str, int):
assert fn.call_count == 0 assert fn.call_count == 0
else: else:
fn.assert_called_once_with(inpt, mode=transform.mode) fn.assert_called_once_with(inpt, mode=transform.mode)
...@@ -1155,7 +1157,7 @@ class TestToImagePIL: ...@@ -1155,7 +1157,7 @@ class TestToImagePIL:
class TestToPILImage: class TestToPILImage:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt_type", "inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
) )
def test__transform(self, inpt_type, mocker): def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil") fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
...@@ -1163,7 +1165,7 @@ class TestToPILImage: ...@@ -1163,7 +1165,7 @@ class TestToPILImage:
inpt = mocker.MagicMock(spec=inpt_type) inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToPILImage() transform = transforms.ToPILImage()
transform(inpt) transform(inpt)
if inpt_type in (PIL.Image.Image, features.BoundingBox, str, int): if inpt_type in (PIL.Image.Image, datapoints.BoundingBox, str, int):
assert fn.call_count == 0 assert fn.call_count == 0
else: else:
fn.assert_called_once_with(inpt, mode=transform.mode) fn.assert_called_once_with(inpt, mode=transform.mode)
...@@ -1172,7 +1174,7 @@ class TestToPILImage: ...@@ -1172,7 +1174,7 @@ class TestToPILImage:
class TestToTensor: class TestToTensor:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inpt_type", "inpt_type",
[torch.Tensor, PIL.Image.Image, features.Image, np.ndarray, features.BoundingBox, str, int], [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
) )
def test__transform(self, inpt_type, mocker): def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_tensor") fn = mocker.patch("torchvision.transforms.functional.to_tensor")
...@@ -1181,7 +1183,7 @@ class TestToTensor: ...@@ -1181,7 +1183,7 @@ class TestToTensor:
with pytest.warns(UserWarning, match="deprecated and will be removed"): with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor() transform = transforms.ToTensor()
transform(inpt) transform(inpt)
if inpt_type in (features.Image, torch.Tensor, features.BoundingBox, str, int): if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBox, str, int):
assert fn.call_count == 0 assert fn.call_count == 0
else: else:
fn.assert_called_once_with(inpt) fn.assert_called_once_with(inpt)
...@@ -1223,10 +1225,10 @@ class TestRandomIoUCrop: ...@@ -1223,10 +1225,10 @@ class TestRandomIoUCrop:
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
def test__get_params(self, device, options, mocker): def test__get_params(self, device, options, mocker):
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=datapoints.Image)
image.num_channels = 3 image.num_channels = 3
image.spatial_size = (24, 32) image.spatial_size = (24, 32)
bboxes = features.BoundingBox( bboxes = datapoints.BoundingBox(
torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]), torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
format="XYXY", format="XYXY",
spatial_size=image.spatial_size, spatial_size=image.spatial_size,
...@@ -1263,9 +1265,9 @@ class TestRandomIoUCrop: ...@@ -1263,9 +1265,9 @@ class TestRandomIoUCrop:
def test__transform_empty_params(self, mocker): def test__transform_empty_params(self, mocker):
transform = transforms.RandomIoUCrop(sampler_options=[2.0]) transform = transforms.RandomIoUCrop(sampler_options=[2.0])
image = features.Image(torch.rand(1, 3, 4, 4)) image = datapoints.Image(torch.rand(1, 3, 4, 4))
bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4)) bboxes = datapoints.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
label = features.Label(torch.tensor([1])) label = datapoints.Label(torch.tensor([1]))
sample = [image, bboxes, label] sample = [image, bboxes, label]
# Let's mock transform._get_params to control the output: # Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock(return_value={}) transform._get_params = mocker.MagicMock(return_value={})
...@@ -1283,10 +1285,10 @@ class TestRandomIoUCrop: ...@@ -1283,10 +1285,10 @@ class TestRandomIoUCrop:
def test__transform(self, mocker): def test__transform(self, mocker):
transform = transforms.RandomIoUCrop() transform = transforms.RandomIoUCrop()
image = features.Image(torch.rand(3, 32, 24)) image = datapoints.Image(torch.rand(3, 32, 24))
bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,)) bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,))
label = features.Label(torch.randint(0, 10, size=(6,))) label = datapoints.Label(torch.randint(0, 10, size=(6,)))
ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) ohe_label = datapoints.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
masks = make_detection_mask((32, 24), num_objects=6) masks = make_detection_mask((32, 24), num_objects=6)
sample = [image, bboxes, label, ohe_label, masks] sample = [image, bboxes, label, ohe_label, masks]
...@@ -1312,21 +1314,21 @@ class TestRandomIoUCrop: ...@@ -1312,21 +1314,21 @@ class TestRandomIoUCrop:
# check number of bboxes vs number of labels: # check number of bboxes vs number of labels:
output_bboxes = output[1] output_bboxes = output[1]
assert isinstance(output_bboxes, features.BoundingBox) assert isinstance(output_bboxes, datapoints.BoundingBox)
assert len(output_bboxes) == expected_within_targets assert len(output_bboxes) == expected_within_targets
# check labels # check labels
output_label = output[2] output_label = output[2]
assert isinstance(output_label, features.Label) assert isinstance(output_label, datapoints.Label)
assert len(output_label) == expected_within_targets assert len(output_label) == expected_within_targets
torch.testing.assert_close(output_label, label[is_within_crop_area]) torch.testing.assert_close(output_label, label[is_within_crop_area])
output_ohe_label = output[3] output_ohe_label = output[3]
assert isinstance(output_ohe_label, features.OneHotLabel) assert isinstance(output_ohe_label, datapoints.OneHotLabel)
torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area]) torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
output_masks = output[4] output_masks = output[4]
assert isinstance(output_masks, features.Mask) assert isinstance(output_masks, datapoints.Mask)
assert len(output_masks) == expected_within_targets assert len(output_masks) == expected_within_targets
...@@ -1337,7 +1339,7 @@ class TestScaleJitter: ...@@ -1337,7 +1339,7 @@ class TestScaleJitter:
scale_range = (0.5, 1.5) scale_range = (0.5, 1.5)
transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range) transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size) sample = mocker.MagicMock(spec=datapoints.Image, num_channels=3, spatial_size=spatial_size)
n_samples = 5 n_samples = 5
for _ in range(n_samples): for _ in range(n_samples):
...@@ -1387,7 +1389,7 @@ class TestRandomShortestSize: ...@@ -1387,7 +1389,7 @@ class TestRandomShortestSize:
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size) transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size) sample = mocker.MagicMock(spec=datapoints.Image, num_channels=3, spatial_size=spatial_size)
params = transform._get_params([sample]) params = transform._get_params([sample])
assert "size" in params assert "size" in params
...@@ -1439,21 +1441,21 @@ class TestSimpleCopyPaste: ...@@ -1439,21 +1441,21 @@ class TestSimpleCopyPaste:
flat_sample = [ flat_sample = [
# images, batch size = 2 # images, batch size = 2
self.create_fake_image(mocker, features.Image), self.create_fake_image(mocker, datapoints.Image),
# labels, bboxes, masks # labels, bboxes, masks
mocker.MagicMock(spec=features.Label), mocker.MagicMock(spec=datapoints.Label),
mocker.MagicMock(spec=features.BoundingBox), mocker.MagicMock(spec=datapoints.BoundingBox),
mocker.MagicMock(spec=features.Mask), mocker.MagicMock(spec=datapoints.Mask),
# labels, bboxes, masks # labels, bboxes, masks
mocker.MagicMock(spec=features.BoundingBox), mocker.MagicMock(spec=datapoints.BoundingBox),
mocker.MagicMock(spec=features.Mask), mocker.MagicMock(spec=datapoints.Mask),
] ]
with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"): with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
transform._extract_image_targets(flat_sample) transform._extract_image_targets(flat_sample)
@pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor]) @pytest.mark.parametrize("image_type", [datapoints.Image, PIL.Image.Image, torch.Tensor])
@pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel]) @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
def test__extract_image_targets(self, image_type, label_type, mocker): def test__extract_image_targets(self, image_type, label_type, mocker):
transform = transforms.SimpleCopyPaste() transform = transforms.SimpleCopyPaste()
...@@ -1463,12 +1465,12 @@ class TestSimpleCopyPaste: ...@@ -1463,12 +1465,12 @@ class TestSimpleCopyPaste:
self.create_fake_image(mocker, image_type), self.create_fake_image(mocker, image_type),
# labels, bboxes, masks # labels, bboxes, masks
mocker.MagicMock(spec=label_type), mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=features.BoundingBox), mocker.MagicMock(spec=datapoints.BoundingBox),
mocker.MagicMock(spec=features.Mask), mocker.MagicMock(spec=datapoints.Mask),
# labels, bboxes, masks # labels, bboxes, masks
mocker.MagicMock(spec=label_type), mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=features.BoundingBox), mocker.MagicMock(spec=datapoints.BoundingBox),
mocker.MagicMock(spec=features.Mask), mocker.MagicMock(spec=datapoints.Mask),
] ]
images, targets = transform._extract_image_targets(flat_sample) images, targets = transform._extract_image_targets(flat_sample)
...@@ -1483,15 +1485,15 @@ class TestSimpleCopyPaste: ...@@ -1483,15 +1485,15 @@ class TestSimpleCopyPaste:
for target in targets: for target in targets:
for key, type_ in [ for key, type_ in [
("boxes", features.BoundingBox), ("boxes", datapoints.BoundingBox),
("masks", features.Mask), ("masks", datapoints.Mask),
("labels", label_type), ("labels", label_type),
]: ]:
assert key in target assert key in target
assert isinstance(target[key], type_) assert isinstance(target[key], type_)
assert target[key] in flat_sample assert target[key] in flat_sample
@pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel]) @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
def test__copy_paste(self, label_type): def test__copy_paste(self, label_type):
image = 2 * torch.ones(3, 32, 32) image = 2 * torch.ones(3, 32, 32)
masks = torch.zeros(2, 32, 32) masks = torch.zeros(2, 32, 32)
...@@ -1501,13 +1503,13 @@ class TestSimpleCopyPaste: ...@@ -1501,13 +1503,13 @@ class TestSimpleCopyPaste:
blending = True blending = True
resize_interpolation = InterpolationMode.BILINEAR resize_interpolation = InterpolationMode.BILINEAR
antialias = None antialias = None
if label_type == features.OneHotLabel: if label_type == datapoints.OneHotLabel:
labels = torch.nn.functional.one_hot(labels, num_classes=5) labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = { target = {
"boxes": features.BoundingBox( "boxes": datapoints.BoundingBox(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32) torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32)
), ),
"masks": features.Mask(masks), "masks": datapoints.Mask(masks),
"labels": label_type(labels), "labels": label_type(labels),
} }
...@@ -1516,13 +1518,13 @@ class TestSimpleCopyPaste: ...@@ -1516,13 +1518,13 @@ class TestSimpleCopyPaste:
paste_masks[0, 13:19, 12:18] = 1 paste_masks[0, 13:19, 12:18] = 1
paste_masks[1, 15:19, 1:8] = 1 paste_masks[1, 15:19, 1:8] = 1
paste_labels = torch.tensor([3, 4]) paste_labels = torch.tensor([3, 4])
if label_type == features.OneHotLabel: if label_type == datapoints.OneHotLabel:
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5) paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = { paste_target = {
"boxes": features.BoundingBox( "boxes": datapoints.BoundingBox(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32) torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32)
), ),
"masks": features.Mask(paste_masks), "masks": datapoints.Mask(paste_masks),
"labels": label_type(paste_labels), "labels": label_type(paste_labels),
} }
...@@ -1538,7 +1540,7 @@ class TestSimpleCopyPaste: ...@@ -1538,7 +1540,7 @@ class TestSimpleCopyPaste:
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"]) torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
expected_labels = torch.tensor([1, 2, 3, 4]) expected_labels = torch.tensor([1, 2, 3, 4])
if label_type == features.OneHotLabel: if label_type == datapoints.OneHotLabel:
expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5) expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
torch.testing.assert_close(output_target["labels"], label_type(expected_labels)) torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
...@@ -1556,9 +1558,9 @@ class TestFixedSizeCrop: ...@@ -1556,9 +1558,9 @@ class TestFixedSizeCrop:
transform = transforms.FixedSizeCrop(size=crop_size) transform = transforms.FixedSizeCrop(size=crop_size)
flat_inputs = [ flat_inputs = [
make_image(size=spatial_size, color_space=features.ColorSpace.RGB), make_image(size=spatial_size, color_space=datapoints.ColorSpace.RGB),
make_bounding_box( make_bounding_box(
format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
), ),
] ]
params = transform._get_params(flat_inputs) params = transform._get_params(flat_inputs)
...@@ -1656,7 +1658,7 @@ class TestFixedSizeCrop: ...@@ -1656,7 +1658,7 @@ class TestFixedSizeCrop:
) )
bounding_boxes = make_bounding_box( bounding_boxes = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,)
) )
masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,)) masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,))
labels = make_label(extra_dims=(batch_size,)) labels = make_label(extra_dims=(batch_size,))
...@@ -1695,7 +1697,7 @@ class TestFixedSizeCrop: ...@@ -1695,7 +1697,7 @@ class TestFixedSizeCrop:
) )
bounding_box = make_bounding_box( bounding_box = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,) format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,)
) )
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box") mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
...@@ -1721,7 +1723,7 @@ class TestLinearTransformation: ...@@ -1721,7 +1723,7 @@ class TestLinearTransformation:
[ [
122 * torch.ones(1, 3, 8, 8), 122 * torch.ones(1, 3, 8, 8),
122.0 * torch.ones(1, 3, 8, 8), 122.0 * torch.ones(1, 3, 8, 8),
features.Image(122 * torch.ones(1, 3, 8, 8)), datapoints.Image(122 * torch.ones(1, 3, 8, 8)),
PIL.Image.new("RGB", (8, 8), (122, 122, 122)), PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
], ],
) )
...@@ -1744,10 +1746,10 @@ class TestLinearTransformation: ...@@ -1744,10 +1746,10 @@ class TestLinearTransformation:
class TestLabelToOneHot: class TestLabelToOneHot:
def test__transform(self): def test__transform(self):
categories = ["apple", "pear", "pineapple"] categories = ["apple", "pear", "pineapple"]
labels = features.Label(torch.tensor([0, 1, 2, 1]), categories=categories) labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
transform = transforms.LabelToOneHot() transform = transforms.LabelToOneHot()
ohe_labels = transform(labels) ohe_labels = transform(labels)
assert isinstance(ohe_labels, features.OneHotLabel) assert isinstance(ohe_labels, datapoints.OneHotLabel)
assert ohe_labels.shape == (4, 3) assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories assert ohe_labels.categories == labels.categories == categories
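The LabelToOneHot test above fixes the shape contract: a Label holding N indices over C categories becomes a OneHotLabel of shape (N, C), with the categories list carried over unchanged. A short usage sketch assuming the prototype API exactly as exercised there:

import torch
from torchvision.prototype import datapoints, transforms

labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=["apple", "pear", "pineapple"])
ohe = transforms.LabelToOneHot()(labels)

assert isinstance(ohe, datapoints.OneHotLabel)
assert ohe.shape == (4, 3)  # (num_labels, num_categories)
assert ohe.categories == labels.categories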
...@@ -1797,11 +1799,11 @@ class TestRandomResize: ...@@ -1797,11 +1799,11 @@ class TestRandomResize:
[ [
( (
torch.float64, torch.float64,
{torch.Tensor: torch.float64, features.Image: torch.float64, features.BoundingBox: torch.float64}, {torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64},
), ),
( (
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64}, {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64}, {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
), ),
], ],
) )
...@@ -1809,7 +1811,7 @@ def test_to_dtype(dtype, expected_dtypes): ...@@ -1809,7 +1811,7 @@ def test_to_dtype(dtype, expected_dtypes):
sample = dict( sample = dict(
plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"), plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
image=make_image(dtype=torch.uint8), image=make_image(dtype=torch.uint8),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY, dtype=torch.float32), bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
str="str", str="str",
int=0, int=0,
) )
...@@ -1834,12 +1836,12 @@ def test_to_dtype(dtype, expected_dtypes): ...@@ -1834,12 +1836,12 @@ def test_to_dtype(dtype, expected_dtypes):
("dims", "inverse_dims"), ("dims", "inverse_dims"),
[ [
( (
{torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: None}, {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None},
{torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: None}, {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None},
), ),
( (
{torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: (1, 2, 3, 0)}, {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
{torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: (3, 0, 1, 2)}, {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
), ),
], ],
) )
...@@ -1847,7 +1849,7 @@ def test_permute_dimensions(dims, inverse_dims): ...@@ -1847,7 +1849,7 @@ def test_permute_dimensions(dims, inverse_dims):
sample = dict( sample = dict(
plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
image=make_image(), image=make_image(),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY), bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
video=make_video(), video=make_video(),
str="str", str="str",
int=0, int=0,
...@@ -1860,7 +1862,9 @@ def test_permute_dimensions(dims, inverse_dims): ...@@ -1860,7 +1862,9 @@ def test_permute_dimensions(dims, inverse_dims):
value_type = type(value) value_type = type(value)
transformed_value = transformed_sample[key] transformed_value = transformed_sample[key]
if check_type(value, (features.Image, features.is_simple_tensor, features.Video)): if check_type(
value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
):
if transform.dims.get(value_type) is not None: if transform.dims.get(value_type) is not None:
assert transformed_value.permute(inverse_dims[value_type]).equal(value) assert transformed_value.permute(inverse_dims[value_type]).equal(value)
assert type(transformed_value) == torch.Tensor assert type(transformed_value) == torch.Tensor
...@@ -1872,14 +1876,14 @@ def test_permute_dimensions(dims, inverse_dims): ...@@ -1872,14 +1876,14 @@ def test_permute_dimensions(dims, inverse_dims):
"dims", "dims",
[ [
(-1, -2), (-1, -2),
{torch.Tensor: (-1, -2), features.Image: (1, 2), features.Video: None}, {torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None},
], ],
) )
def test_transpose_dimensions(dims): def test_transpose_dimensions(dims):
sample = dict( sample = dict(
plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
image=make_image(), image=make_image(),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY), bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
video=make_video(), video=make_video(),
str="str", str="str",
int=0, int=0,
...@@ -1893,7 +1897,9 @@ def test_transpose_dimensions(dims): ...@@ -1893,7 +1897,9 @@ def test_transpose_dimensions(dims):
transformed_value = transformed_sample[key] transformed_value = transformed_sample[key]
transposed_dims = transform.dims.get(value_type) transposed_dims = transform.dims.get(value_type)
if check_type(value, (features.Image, features.is_simple_tensor, features.Video)): if check_type(
value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
):
if transposed_dims is not None: if transposed_dims is not None:
assert transformed_value.transpose(*transposed_dims).equal(value) assert transformed_value.transpose(*transposed_dims).equal(value)
assert type(transformed_value) == torch.Tensor assert type(transformed_value) == torch.Tensor
...@@ -1907,7 +1913,7 @@ class TestUniformTemporalSubsample: ...@@ -1907,7 +1913,7 @@ class TestUniformTemporalSubsample:
[ [
torch.zeros(10, 3, 8, 8), torch.zeros(10, 3, 8, 8),
torch.zeros(1, 10, 3, 8, 8), torch.zeros(1, 10, 3, 8, 8),
features.Video(torch.zeros(1, 10, 3, 8, 8)), datapoints.Video(torch.zeros(1, 10, 3, 8, 8)),
], ],
) )
def test__transform(self, inpt): def test__transform(self, inpt):
......
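The hunks above swap `features.is_simple_tensor` for the helper that now lives in `torchvision.prototype.transforms.utils`. A minimal sketch of how that predicate behaves, assuming it keeps the semantics of the removed `features.is_simple_tensor` shown further down in this commit: a plain tensor passes, any datapoint subclass does not.

```python
import torch

from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import is_simple_tensor

plain = torch.rand(3, 8, 8)
image = datapoints.Image(torch.randint(0, 256, size=(3, 8, 8), dtype=torch.uint8))

assert is_simple_tensor(plain)      # plain tensors are "simple"
assert not is_simple_tensor(image)  # Datapoint subclasses (Image, Video, ...) are not
```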
...@@ -24,13 +24,13 @@ from prototype_common_utils import ( ...@@ -24,13 +24,13 @@ from prototype_common_utils import (
) )
from torchvision import transforms as legacy_transforms from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms from torchvision.prototype import datapoints, transforms as prototype_transforms
from torchvision.prototype.transforms import functional as prototype_F from torchvision.prototype.transforms import functional as prototype_F
from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import query_spatial_size from torchvision.prototype.transforms.utils import query_spatial_size
from torchvision.transforms import functional as legacy_F from torchvision.transforms import functional as legacy_F
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)])
class ConsistencyConfig: class ConsistencyConfig:
...@@ -138,7 +138,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -138,7 +138,7 @@ CONSISTENCY_CONFIGS = [
# Make sure that the product of the height, width and number of channels matches the number of elements in # Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict( make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[features.ColorSpace.RGB] DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB]
), ),
supports_pil=False, supports_pil=False,
), ),
...@@ -150,7 +150,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -150,7 +150,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(num_output_channels=3), ArgsKwargs(num_output_channels=3),
], ],
make_images_kwargs=dict( make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[features.ColorSpace.RGB, features.ColorSpace.GRAY] DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY]
), ),
), ),
ConsistencyConfig( ConsistencyConfig(
...@@ -173,10 +173,10 @@ CONSISTENCY_CONFIGS = [ ...@@ -173,10 +173,10 @@ CONSISTENCY_CONFIGS = [
[ArgsKwargs()], [ArgsKwargs()],
make_images_kwargs=dict( make_images_kwargs=dict(
color_spaces=[ color_spaces=[
features.ColorSpace.GRAY, datapoints.ColorSpace.GRAY,
features.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.GRAY_ALPHA,
features.ColorSpace.RGB, datapoints.ColorSpace.RGB,
features.ColorSpace.RGB_ALPHA, datapoints.ColorSpace.RGB_ALPHA,
], ],
extra_dims=[()], extra_dims=[()],
), ),
...@@ -733,7 +733,7 @@ class TestAATransforms: ...@@ -733,7 +733,7 @@ class TestAATransforms:
[ [
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123), PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -771,7 +771,7 @@ class TestAATransforms: ...@@ -771,7 +771,7 @@ class TestAATransforms:
[ [
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123), PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -819,7 +819,7 @@ class TestAATransforms: ...@@ -819,7 +819,7 @@ class TestAATransforms:
[ [
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123), PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -868,7 +868,7 @@ class TestAATransforms: ...@@ -868,7 +868,7 @@ class TestAATransforms:
[ [
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123), PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -902,7 +902,7 @@ class TestRefDetTransforms: ...@@ -902,7 +902,7 @@ class TestRefDetTransforms:
size = (600, 800) size = (600, 800)
num_objects = 22 num_objects = 22
pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB)) pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
target = { target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
...@@ -912,7 +912,7 @@ class TestRefDetTransforms: ...@@ -912,7 +912,7 @@ class TestRefDetTransforms:
yield (pil_image, target) yield (pil_image, target)
tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB)) tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
target = { target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
...@@ -922,7 +922,7 @@ class TestRefDetTransforms: ...@@ -922,7 +922,7 @@ class TestRefDetTransforms:
yield (tensor_image, target) yield (tensor_image, target)
feature_image = make_image(size=size, color_space=features.ColorSpace.RGB) feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
target = { target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
...@@ -1006,7 +1006,7 @@ class TestRefSegTransforms: ...@@ -1006,7 +1006,7 @@ class TestRefSegTransforms:
conv_fns.extend([torch.Tensor, lambda x: x]) conv_fns.extend([torch.Tensor, lambda x: x])
for conv_fn in conv_fns: for conv_fn in conv_fns:
feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype) feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
dp = (conv_fn(feature_image), feature_mask) dp = (conv_fn(feature_image), feature_mask)
...@@ -1053,7 +1053,7 @@ class TestRefSegTransforms: ...@@ -1053,7 +1053,7 @@ class TestRefSegTransforms:
seg_transforms.RandomCrop(size=480), seg_transforms.RandomCrop(size=480),
prototype_transforms.Compose( prototype_transforms.Compose(
[ [
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})), PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
prototype_transforms.RandomCrop(size=480), prototype_transforms.RandomCrop(size=480),
] ]
), ),
......
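The segmentation-reference config above passes a per-type fill mapping so masks are padded with the ignore value while everything else gets zeros. A hedged sketch of the same pattern against the public prototype transform, assuming `RandomCrop` exposes `pad_if_needed`/`fill` parameters like the classic transform and accepts the same per-type mapping (`PadIfSmaller` itself is a reference-script helper, not a torchvision API):

```python
from collections import defaultdict

from torchvision.prototype import datapoints, transforms as prototype_transforms

# 0 for images and plain tensors, 255 (the usual ignore index) for segmentation masks.
fill = defaultdict(lambda: 0, {datapoints.Mask: 255})
crop = prototype_transforms.RandomCrop(size=480, pad_if_needed=True, fill=fill)
```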
...@@ -10,12 +10,14 @@ import PIL.Image ...@@ -10,12 +10,14 @@ import PIL.Image
import pytest import pytest
import torch import torch
import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box
...@@ -147,18 +149,22 @@ class TestKernels: ...@@ -147,18 +149,22 @@ class TestKernels:
def test_batched_vs_single(self, test_id, info, args_kwargs, device): def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device) (batched_input, *other_args), kwargs = args_kwargs.load(device)
feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input) feature_type = (
datapoints.Image
if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
else type(batched_input)
)
# This dictionary contains the number of rightmost dimensions that contain the actual data. # This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension. # Everything to the left is considered a batch dimension.
data_dims = { data_dims = {
features.Image: 3, datapoints.Image: 3,
features.BoundingBox: 1, datapoints.BoundingBox: 1,
# `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
# it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped under one # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped under one
# type all kernels should also work without differentiating between the two. Thus, we go with 2 here as # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
# common ground. # common ground.
features.Mask: 2, datapoints.Mask: 2,
features.Video: 4, datapoints.Video: 4,
}.get(feature_type) }.get(feature_type)
if data_dims is None: if data_dims is None:
raise pytest.UsageError( raise pytest.UsageError(
...@@ -281,8 +287,8 @@ def spy_on(mocker): ...@@ -281,8 +287,8 @@ def spy_on(mocker):
class TestDispatchers: class TestDispatchers:
image_sample_inputs = make_info_args_kwargs_parametrization( image_sample_inputs = make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if features.Image in info.kernels], [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
args_kwargs_fn=lambda info: info.sample_inputs(features.Image), args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
) )
@ignore_jit_warning_no_profile @ignore_jit_warning_no_profile
...@@ -323,7 +329,7 @@ class TestDispatchers: ...@@ -323,7 +329,7 @@ class TestDispatchers:
(image_feature, *other_args), kwargs = args_kwargs.load() (image_feature, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = torch.Tensor(image_feature) image_simple_tensor = torch.Tensor(image_feature)
kernel_info = info.kernel_infos[features.Image] kernel_info = info.kernel_infos[datapoints.Image]
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id) spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
info.dispatcher(image_simple_tensor, *other_args, **kwargs) info.dispatcher(image_simple_tensor, *other_args, **kwargs)
...@@ -332,7 +338,7 @@ class TestDispatchers: ...@@ -332,7 +338,7 @@ class TestDispatchers:
@make_info_args_kwargs_parametrization( @make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None], [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(features.Image), args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
) )
def test_dispatch_pil(self, info, args_kwargs, spy_on): def test_dispatch_pil(self, info, args_kwargs, spy_on):
(image_feature, *other_args), kwargs = args_kwargs.load() (image_feature, *other_args), kwargs = args_kwargs.load()
...@@ -403,7 +409,7 @@ class TestDispatchers: ...@@ -403,7 +409,7 @@ class TestDispatchers:
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id) @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_feature_signatures_consistency(self, info): def test_dispatcher_feature_signatures_consistency(self, info):
try: try:
feature_method = getattr(features._Feature, info.id) feature_method = getattr(datapoints._datapoint.Datapoint, info.id)
except AttributeError: except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.") pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")
...@@ -413,7 +419,7 @@ class TestDispatchers: ...@@ -413,7 +419,7 @@ class TestDispatchers:
feature_signature = inspect.signature(feature_method) feature_signature = inspect.signature(feature_method)
feature_params = list(feature_signature.parameters.values())[1:] feature_params = list(feature_signature.parameters.values())[1:]
# Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined, # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint.Datapoint` is defined,
# the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
# concrete dispatcher annotations. # concrete dispatcher annotations.
feature_annotations = get_type_hints(feature_method) feature_annotations = get_type_hints(feature_method)
...@@ -505,8 +511,12 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): ...@@ -505,8 +511,12 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
[spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10], [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
[1, 1, 5, 5], [1, 1, 5, 5],
] ]
in_boxes = features.BoundingBox( in_boxes = datapoints.BoundingBox(
in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, dtype=torch.float64, device=device in_boxes,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=spatial_size,
dtype=torch.float64,
device=device,
) )
# Tested parameters # Tested parameters
angle = 63 angle = 63
...@@ -572,7 +582,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -572,7 +582,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
height, width = bbox.spatial_size height, width = bbox.spatial_size
bbox_xyxy = convert_format_bounding_box( bbox_xyxy = convert_format_bounding_box(
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
) )
points = np.array( points = np.array(
[ [
...@@ -605,15 +615,15 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -605,15 +615,15 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
height = int(height - 2 * tr_y) height = int(height - 2 * tr_y)
width = int(width - 2 * tr_x) width = int(width - 2 * tr_x)
out_bbox = features.BoundingBox( out_bbox = datapoints.BoundingBox(
out_bbox, out_bbox,
format=features.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(height, width), spatial_size=(height, width),
dtype=bbox.dtype, dtype=bbox.dtype,
device=bbox.device, device=bbox.device,
) )
return ( return (
convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format), convert_format_bounding_box(out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format),
(height, width), (height, width),
) )
...@@ -641,7 +651,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -641,7 +651,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bbox, expected_spatial_size = _compute_expected_bbox(bbox, -angle, expand, center_) expected_bbox, expected_spatial_size = _compute_expected_bbox(bbox, -angle, expand, center_)
expected_bboxes.append(expected_bbox) expected_bboxes.append(expected_bbox)
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
...@@ -664,8 +674,12 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): ...@@ -664,8 +674,12 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
[spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2], [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
[spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10], [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
] ]
in_boxes = features.BoundingBox( in_boxes = datapoints.BoundingBox(
in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, dtype=torch.float64, device=device in_boxes,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=spatial_size,
dtype=torch.float64,
device=device,
) )
# Tested parameters # Tested parameters
angle = 45 angle = 45
...@@ -725,7 +739,7 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device): ...@@ -725,7 +739,7 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"format", "format",
[features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH], [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"top, left, height, width, expected_bboxes", "top, left, height, width, expected_bboxes",
...@@ -755,9 +769,11 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, ...@@ -755,9 +769,11 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
[50.0, 5.0, 70.0, 22.0], [50.0, 5.0, 70.0, 22.0],
[45.0, 46.0, 56.0, 62.0], [45.0, 46.0, 56.0, 62.0],
] ]
in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=size, device=device) in_boxes = datapoints.BoundingBox(
if format != features.BoundingBoxFormat.XYXY: in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=size, device=device
in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) )
if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
output_boxes, output_spatial_size = F.crop_bounding_box( output_boxes, output_spatial_size = F.crop_bounding_box(
in_boxes, in_boxes,
...@@ -768,8 +784,8 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, ...@@ -768,8 +784,8 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
size[1], size[1],
) )
if format != features.BoundingBoxFormat.XYXY: if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_spatial_size, size) torch.testing.assert_close(output_spatial_size, size)
...@@ -802,7 +818,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): ...@@ -802,7 +818,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"format", "format",
[features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH], [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"top, left, height, width, size", "top, left, height, width, size",
...@@ -831,16 +847,16 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height ...@@ -831,16 +847,16 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size)) expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
expected_bboxes = torch.tensor(expected_bboxes, device=device) expected_bboxes = torch.tensor(expected_bboxes, device=device)
in_boxes = features.BoundingBox( in_boxes = datapoints.BoundingBox(
in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device
) )
if format != features.BoundingBoxFormat.XYXY: if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size) output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
if format != features.BoundingBoxFormat.XYXY: if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_spatial_size, size) torch.testing.assert_close(output_spatial_size, size)
...@@ -868,14 +884,14 @@ def test_correctness_pad_bounding_box(device, padding): ...@@ -868,14 +884,14 @@ def test_correctness_pad_bounding_box(device, padding):
bbox_dtype = bbox.dtype bbox_dtype = bbox.dtype
bbox = ( bbox = (
bbox.clone() bbox.clone()
if bbox_format == features.BoundingBoxFormat.XYXY if bbox_format == datapoints.BoundingBoxFormat.XYXY
else convert_format_bounding_box(bbox, bbox_format, features.BoundingBoxFormat.XYXY) else convert_format_bounding_box(bbox, bbox_format, datapoints.BoundingBoxFormat.XYXY)
) )
bbox[0::2] += pad_left bbox[0::2] += pad_left
bbox[1::2] += pad_up bbox[1::2] += pad_up
bbox = convert_format_bounding_box(bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format) bbox = convert_format_bounding_box(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format)
if bbox.dtype != bbox_dtype: if bbox.dtype != bbox_dtype:
# Temporary cast to original dtype # Temporary cast to original dtype
# e.g. float32 -> int # e.g. float32 -> int
...@@ -903,7 +919,7 @@ def test_correctness_pad_bounding_box(device, padding): ...@@ -903,7 +919,7 @@ def test_correctness_pad_bounding_box(device, padding):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, padding)) expected_bboxes.append(_compute_expected_bbox(bbox, padding))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
...@@ -949,7 +965,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -949,7 +965,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
) )
bbox_xyxy = convert_format_bounding_box( bbox_xyxy = convert_format_bounding_box(
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY bbox, old_format=bbox.format, new_format=datapoints.BoundingBoxFormat.XYXY
) )
points = np.array( points = np.array(
[ [
...@@ -968,14 +984,16 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -968,14 +984,16 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
np.max(transformed_points[:, 0]), np.max(transformed_points[:, 0]),
np.max(transformed_points[:, 1]), np.max(transformed_points[:, 1]),
] ]
out_bbox = features.BoundingBox( out_bbox = datapoints.BoundingBox(
np.array(out_bbox), np.array(out_bbox),
format=features.BoundingBoxFormat.XYXY, format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=bbox.spatial_size, spatial_size=bbox.spatial_size,
dtype=bbox.dtype, dtype=bbox.dtype,
device=bbox.device, device=bbox.device,
) )
return convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format) return convert_format_bounding_box(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox.format
)
spatial_size = (32, 38) spatial_size = (32, 38)
...@@ -1000,7 +1018,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -1000,7 +1018,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
...@@ -1019,7 +1037,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -1019,7 +1037,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
format_ = bbox.format format_ = bbox.format
spatial_size_ = bbox.spatial_size spatial_size_ = bbox.spatial_size
dtype = bbox.dtype dtype = bbox.dtype
bbox = convert_format_bounding_box(bbox.float(), format_, features.BoundingBoxFormat.XYWH) bbox = convert_format_bounding_box(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
if len(output_size_) == 1: if len(output_size_) == 1:
output_size_.append(output_size_[-1]) output_size_.append(output_size_[-1])
...@@ -1033,7 +1051,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -1033,7 +1051,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
bbox[3].item(), bbox[3].item(),
] ]
out_bbox = torch.tensor(out_bbox) out_bbox = torch.tensor(out_bbox)
out_bbox = convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_) out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
return out_bbox.to(dtype=dtype, device=bbox.device) return out_bbox.to(dtype=dtype, device=bbox.device)
for bboxes in make_bounding_boxes(extra_dims=((4,),)): for bboxes in make_bounding_boxes(extra_dims=((4,),)):
...@@ -1050,7 +1068,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -1050,7 +1068,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size) bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
...@@ -1135,7 +1153,7 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, ...@@ -1135,7 +1153,7 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
) )
image = features.Image(tensor) image = datapoints.Image(tensor)
out = fn(image, kernel_size=ksize, sigma=sigma) out = fn(image, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
...@@ -1147,7 +1165,7 @@ def test_normalize_output_type(): ...@@ -1147,7 +1165,7 @@ def test_normalize_output_type():
assert type(output) is torch.Tensor assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
inpt = make_image(color_space=features.ColorSpace.RGB) inpt = make_image(color_space=datapoints.ColorSpace.RGB)
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert type(output) is torch.Tensor assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
......
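Several of the correctness tests above round-trip boxes through `convert_format_bounding_box` with positional `old_format`/`new_format` arguments. A small sketch mirroring that usage, with the function imported from the private `_meta` module exactly as in the test file:

```python
import torch

from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box

box = datapoints.BoundingBox(
    [[10.0, 15.0, 25.0, 35.0]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(64, 76),
)

# XYXY -> CXCYWH and back, as the crop/resized-crop tests above do.
cxcywh = convert_format_bounding_box(box, datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.CXCYWH)
xyxy = convert_format_bounding_box(cxcywh, datapoints.BoundingBoxFormat.CXCYWH, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(torch.Tensor(xyxy), torch.Tensor(box))
```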
...@@ -3,42 +3,51 @@ import pytest ...@@ -3,42 +3,51 @@ import pytest
import torch import torch
import torchvision.prototype.transforms.utils
from prototype_common_utils import make_bounding_box, make_detection_mask, make_image from prototype_common_utils import make_bounding_box, make_detection_mask, make_image
from torchvision.prototype import features from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import has_all, has_any from torchvision.prototype.transforms.utils import has_all, has_any
IMAGE = make_image(color_space=features.ColorSpace.RGB) IMAGE = make_image(color_space=datapoints.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size)
MASK = make_detection_mask(size=IMAGE.spatial_size) MASK = make_detection_mask(size=IMAGE.spatial_size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("sample", "types", "expected"), ("sample", "types", "expected"),
[ [
((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True),
((MASK,), (features.Image, features.BoundingBox), False), ((MASK,), (datapoints.Image, datapoints.BoundingBox), False),
((BOUNDING_BOX,), (features.Image, features.Mask), False), ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
((IMAGE,), (features.BoundingBox, features.Mask), False), ((IMAGE,), (datapoints.BoundingBox, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
True, True,
), ),
((), (features.Image, features.BoundingBox, features.Mask), False), ((), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, features.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((IMAGE,), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), (
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True), (torch.Tensor(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor),
True,
),
(
(to_image_pil(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor),
True,
),
], ],
) )
def test_has_any(sample, types, expected): def test_has_any(sample, types, expected):
...@@ -48,31 +57,31 @@ def test_has_any(sample, types, expected): ...@@ -48,31 +57,31 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("sample", "types", "expected"), ("sample", "types", "expected"),
[ [
((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
True, True,
), ),
((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), False),
((BOUNDING_BOX, MASK), (features.Image, features.Mask), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
((IMAGE, MASK), (features.BoundingBox, features.Mask), False), ((IMAGE, MASK), (datapoints.BoundingBox, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
True, True,
), ),
((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox, features.Mask), False), ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
((IMAGE, MASK), (features.Image, features.BoundingBox, features.Mask), False), ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.Mask), False), ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.Mask)),), (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)),),
True, True,
), ),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
......
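For reference, the checks exercised by these parametrizations, assuming `has_any`/`has_all` take the flat sample followed by the types or callables to look for, as the parametrize tuples above suggest:

```python
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import has_all, has_any

sample = (IMAGE, BOUNDING_BOX, MASK)  # the module-level fixtures defined above

assert has_any(sample, datapoints.Image)
assert has_all(sample, datapoints.Image, datapoints.BoundingBox, datapoints.Mask)
assert not has_any((MASK,), datapoints.Image, datapoints.BoundingBox)
assert not has_all((IMAGE, MASK), datapoints.Image, datapoints.BoundingBox, datapoints.Mask)
```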
from . import features, models, transforms, utils from . import datapoints, models, transforms, utils
from ._bounding_box import BoundingBox, BoundingBoxFormat from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel from ._label import Label, OneHotLabel
from ._mask import Mask from ._mask import Mask
......
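The namespace change in a nutshell: user code now imports `datapoints` where it previously imported `features`, while the base class is reachable only through the private module, as the dataset hunks below do. A hedged before/after sketch:

```python
import torch

# before: from torchvision.prototype import features; features.Image(...)
from torchvision.prototype import datapoints

image = datapoints.Image(torch.rand(3, 32, 32))
box = datapoints.BoundingBox(
    [0.0, 0.0, 10.0, 10.0],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),
)

# The base class is no longer re-exported from the package __init__:
from torchvision.prototype.datapoints._datapoint import Datapoint
```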
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from torchvision._utils import StrEnum from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
from ._feature import _Feature, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
class BoundingBoxFormat(StrEnum): class BoundingBoxFormat(StrEnum):
...@@ -15,7 +15,7 @@ class BoundingBoxFormat(StrEnum): ...@@ -15,7 +15,7 @@ class BoundingBoxFormat(StrEnum):
CXCYWH = StrEnum.auto() CXCYWH = StrEnum.auto()
class BoundingBox(_Feature): class BoundingBox(Datapoint):
format: BoundingBoxFormat format: BoundingBoxFormat
spatial_size: Tuple[int, int] spatial_size: Tuple[int, int]
......
...@@ -10,16 +10,12 @@ from torch.types import _device, _dtype, _size ...@@ -10,16 +10,12 @@ from torch.types import _device, _dtype, _size
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature") D = TypeVar("D", bound="Datapoint")
FillType = Union[int, float, Sequence[int], Sequence[float], None] FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Union[int, float, List[float], None] FillTypeJIT = Union[int, float, List[float], None]
def is_simple_tensor(inpt: Any) -> bool: class Datapoint(torch.Tensor):
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
class _Feature(torch.Tensor):
__F: Optional[ModuleType] = None __F: Optional[ModuleType] = None
@staticmethod @staticmethod
...@@ -31,22 +27,22 @@ class _Feature(torch.Tensor): ...@@ -31,22 +27,22 @@ class _Feature(torch.Tensor):
) -> torch.Tensor: ) -> torch.Tensor:
return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
# FIXME: this is just here for BC with the prototype datasets. Some datasets use the _Feature directly to have a # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
# a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
# interpreted as images. We should decide if we want a public no-op feature like `GenericFeature` or make this one # interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this
# public again. # one public again.
def __new__( def __new__(
cls, cls,
data: Any, data: Any,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: bool = False,
) -> _Feature: ) -> Datapoint:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return tensor.as_subclass(_Feature) return tensor.as_subclass(Datapoint)
@classmethod @classmethod
def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F: def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
# FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved, # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
# this method should be made abstract # this method should be made abstract
# raise NotImplementedError # raise NotImplementedError
...@@ -75,15 +71,15 @@ class _Feature(torch.Tensor): ...@@ -75,15 +71,15 @@ class _Feature(torch.Tensor):
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call. ``args`` and ``kwargs`` of the original call.
The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`_Feature` The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
use case, this has two downsides: use case, this has two downsides:
1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e. 1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
``return cls(func(*args, **kwargs))``, will fail for them. ``return cls(func(*args, **kwargs))``, will fail for them.
2. For most operations, there is no way of knowing if the input type is still valid for the output. 2. For most operations, there is no way of knowing if the input type is still valid for the output.
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in :attr:`~_Feature._NO_WRAPPING_EXCEPTIONS` listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
""" """
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality. # need to reimplement the functionality.
...@@ -98,9 +94,9 @@ class _Feature(torch.Tensor): ...@@ -98,9 +94,9 @@ class _Feature(torch.Tensor):
# Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(features.Image(...))` will invoke `features.Image.__torch_function__` with # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), features.Image())` first. Without this guard, the original `torch.Tensor` would # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `features.Image`. # be wrapped into a `datapoints.Image`.
if wrapper and isinstance(args[0], cls): if wrapper and isinstance(args[0], cls):
return wrapper(cls, args[0], output) # type: ignore[no-any-return] return wrapper(cls, args[0], output) # type: ignore[no-any-return]
...@@ -123,11 +119,11 @@ class _Feature(torch.Tensor): ...@@ -123,11 +119,11 @@ class _Feature(torch.Tensor):
# until the first time we need reference to the functional module and it's shared across all instances of # until the first time we need reference to the functional module and it's shared across all instances of
# the class. This approach avoids the DataLoader issue described at # the class. This approach avoids the DataLoader issue described at
# https://github.com/pytorch/vision/pull/6476#discussion_r953588621 # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
if _Feature.__F is None: if Datapoint.__F is None:
from ..transforms import functional from ..transforms import functional
_Feature.__F = functional Datapoint.__F = functional
return _Feature.__F return Datapoint.__F
# Add properties for common attributes like shape, dtype, device, ndim etc # Add properties for common attributes like shape, dtype, device, ndim etc
# this way we return the result without passing into __torch_function__ # this way we return the result without passing into __torch_function__
...@@ -151,10 +147,10 @@ class _Feature(torch.Tensor): ...@@ -151,10 +147,10 @@ class _Feature(torch.Tensor):
with DisableTorchFunction(): with DisableTorchFunction():
return super().dtype return super().dtype
def horizontal_flip(self) -> _Feature: def horizontal_flip(self) -> Datapoint:
return self return self
def vertical_flip(self) -> _Feature: def vertical_flip(self) -> Datapoint:
return self return self
# TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
...@@ -165,13 +161,13 @@ class _Feature(torch.Tensor): ...@@ -165,13 +161,13 @@ class _Feature(torch.Tensor):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> _Feature: ) -> Datapoint:
return self return self
def crop(self, top: int, left: int, height: int, width: int) -> _Feature: def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
return self return self
def center_crop(self, output_size: List[int]) -> _Feature: def center_crop(self, output_size: List[int]) -> Datapoint:
return self return self
def resized_crop( def resized_crop(
...@@ -183,7 +179,7 @@ class _Feature(torch.Tensor): ...@@ -183,7 +179,7 @@ class _Feature(torch.Tensor):
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> _Feature: ) -> Datapoint:
return self return self
def pad( def pad(
...@@ -191,7 +187,7 @@ class _Feature(torch.Tensor): ...@@ -191,7 +187,7 @@ class _Feature(torch.Tensor):
padding: Union[int, List[int]], padding: Union[int, List[int]],
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> _Feature: ) -> Datapoint:
return self return self
def rotate( def rotate(
...@@ -201,7 +197,7 @@ class _Feature(torch.Tensor): ...@@ -201,7 +197,7 @@ class _Feature(torch.Tensor):
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> _Feature: ) -> Datapoint:
return self return self
def affine( def affine(
...@@ -213,7 +209,7 @@ class _Feature(torch.Tensor): ...@@ -213,7 +209,7 @@ class _Feature(torch.Tensor):
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> _Feature: ) -> Datapoint:
return self return self
def perspective( def perspective(
...@@ -223,7 +219,7 @@ class _Feature(torch.Tensor): ...@@ -223,7 +219,7 @@ class _Feature(torch.Tensor):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> _Feature: ) -> Datapoint:
return self return self
def elastic( def elastic(
...@@ -231,45 +227,45 @@ class _Feature(torch.Tensor): ...@@ -231,45 +227,45 @@ class _Feature(torch.Tensor):
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> _Feature: ) -> Datapoint:
return self return self
def adjust_brightness(self, brightness_factor: float) -> _Feature: def adjust_brightness(self, brightness_factor: float) -> Datapoint:
return self return self
def adjust_saturation(self, saturation_factor: float) -> _Feature: def adjust_saturation(self, saturation_factor: float) -> Datapoint:
return self return self
def adjust_contrast(self, contrast_factor: float) -> _Feature: def adjust_contrast(self, contrast_factor: float) -> Datapoint:
return self return self
def adjust_sharpness(self, sharpness_factor: float) -> _Feature: def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
return self return self
def adjust_hue(self, hue_factor: float) -> _Feature: def adjust_hue(self, hue_factor: float) -> Datapoint:
return self return self
def adjust_gamma(self, gamma: float, gain: float = 1) -> _Feature: def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
return self return self
def posterize(self, bits: int) -> _Feature: def posterize(self, bits: int) -> Datapoint:
return self return self
def solarize(self, threshold: float) -> _Feature: def solarize(self, threshold: float) -> Datapoint:
return self return self
def autocontrast(self) -> _Feature: def autocontrast(self) -> Datapoint:
return self return self
def equalize(self) -> _Feature: def equalize(self) -> Datapoint:
return self return self
def invert(self) -> _Feature: def invert(self) -> Datapoint:
return self return self
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
return self return self
InputType = Union[torch.Tensor, PIL.Image.Image, _Feature] InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
InputTypeJIT = torch.Tensor InputTypeJIT = torch.Tensor
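The base class above implements every dispatchable kernel (`horizontal_flip`, `resize`, `pad`, the color ops, ...) as a no-op returning `self`, so unknown datapoint types pass through the transforms untouched, and `__torch_function__` unwraps most other operators back to plain tensors. A small hypothetical subclass, shown only to illustrate the API, that relies on those fallbacks and on a `wrap_like` override in the style of the `_wrap` classmethods below:

```python
import torch

from torchvision.prototype.datapoints._datapoint import Datapoint


class TimestampDatapoint(Datapoint):
    # Hypothetical subclass for illustration; not part of torchvision.
    @classmethod
    def wrap_like(cls, other, tensor):
        return tensor.as_subclass(cls)


ts = TimestampDatapoint.wrap_like(None, torch.arange(4))
assert ts.horizontal_flip() is ts        # geometry kernels fall back to returning self
assert ts.adjust_brightness(2.0) is ts   # so do the color kernels
assert type(ts + 1) is torch.Tensor      # generic tensor ops drop the subclass (see __torch_function__ above)
```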
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
from torchvision._utils import StrEnum from torchvision._utils import StrEnum
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ._feature import _Feature, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
class ColorSpace(StrEnum): class ColorSpace(StrEnum):
...@@ -57,7 +57,7 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace: ...@@ -57,7 +57,7 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:
return ColorSpace.OTHER return ColorSpace.OTHER
class Image(_Feature): class Image(Datapoint):
color_space: ColorSpace color_space: ColorSpace
@classmethod @classmethod
......
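`Image` is a thin `Datapoint` subclass carrying a `color_space` field. Constructed exactly as in the AutoAugment tests above, it remains a regular tensor for all practical purposes; the color space is presumably filled in via the `_from_tensor_shape` helper shown in this hunk when the caller does not pass one explicitly:

```python
import torch

from torchvision.prototype import datapoints

img = datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8))

assert isinstance(img, torch.Tensor)
assert type(img) is datapoints.Image
print(img.color_space)  # e.g. ColorSpace.RGB for a 3-channel tensor
```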
...@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union ...@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from ._feature import _Feature from ._datapoint import Datapoint
L = TypeVar("L", bound="_LabelBase") L = TypeVar("L", bound="_LabelBase")
class _LabelBase(_Feature): class _LabelBase(Datapoint):
categories: Optional[Sequence[str]] categories: Optional[Sequence[str]]
@classmethod @classmethod
......
...@@ -5,10 +5,10 @@ from typing import Any, List, Optional, Tuple, Union ...@@ -5,10 +5,10 @@ from typing import Any, List, Optional, Tuple, Union
import torch import torch
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from ._feature import _Feature, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
class Mask(_Feature): class Mask(Datapoint):
@classmethod @classmethod
def _wrap(cls, tensor: torch.Tensor) -> Mask: def _wrap(cls, tensor: torch.Tensor) -> Mask:
return tensor.as_subclass(cls) return tensor.as_subclass(cls)
......
...@@ -6,11 +6,11 @@ from typing import Any, List, Optional, Tuple, Union ...@@ -6,11 +6,11 @@ from typing import Any, List, Optional, Tuple, Union
import torch import torch
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ._feature import _Feature, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
from ._image import ColorSpace from ._image import ColorSpace
class Video(_Feature): class Video(Datapoint):
color_space: ColorSpace color_space: ColorSpace
@classmethod @classmethod
......
...@@ -4,6 +4,8 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union ...@@ -4,6 +4,8 @@ from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
hint_sharding, hint_sharding,
...@@ -12,7 +14,6 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -12,7 +14,6 @@ from torchvision.prototype.datasets.utils._internal import (
read_categories_file, read_categories_file,
read_mat, read_mat,
) )
from torchvision.prototype.features import _Feature, BoundingBox, Label
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
...@@ -114,7 +115,7 @@ class Caltech101(Dataset): ...@@ -114,7 +115,7 @@ class Caltech101(Dataset):
format="xyxy", format="xyxy",
spatial_size=image.spatial_size, spatial_size=image.spatial_size,
), ),
contour=_Feature(ann["obj_contour"].T), contour=Datapoint(ann["obj_contour"].T),
) )
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
......
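The Caltech101 sample above wraps the raw contour array in the bare `Datapoint` so the prototype transforms treat it as a no-op input instead of an image, which is exactly the BC use case the FIXME in `_datapoint.py` describes. A short sketch of the pattern with a placeholder array standing in for the annotation data:

```python
import numpy as np

from torchvision.prototype.datapoints._datapoint import Datapoint

obj_contour = np.zeros((2, 10))  # placeholder standing in for ann["obj_contour"]
contour = Datapoint(obj_contour.T)

# The wrapper is a plain no-op datapoint: transforms pass it through unchanged.
assert contour.horizontal_flip() is contour
```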
...@@ -3,6 +3,8 @@ import pathlib ...@@ -3,6 +3,8 @@ import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
...@@ -11,7 +13,6 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -11,7 +13,6 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
path_accessor, path_accessor,
) )
from torchvision.prototype.features import _Feature, BoundingBox, Label
from .._api import register_dataset, register_info from .._api import register_dataset, register_info
...@@ -148,7 +149,7 @@ class CelebA(Dataset): ...@@ -148,7 +149,7 @@ class CelebA(Dataset):
spatial_size=image.spatial_size, spatial_size=image.spatial_size,
), ),
landmarks={ landmarks={
landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) landmark: Datapoint((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
for landmark in {key[:-2] for key in landmarks.keys()} for landmark in {key[:-2] for key in landmarks.keys()}
}, },
) )
......