Unverified Commit 01d138d8 authored by Philip Meier, committed by GitHub

update naming feature -> datapoint in prototype test suite (#7117)

parent d7e5b6a1
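
For orientation before the hunks: the change is a pure rename of "feature" to "datapoint" across the prototype test suite, so call sites change spelling but not behaviour. A minimal sketch of the new spelling, using the helpers visible in the diff (the import path and the loop body are illustrative assumptions; `info` stands for any DispatcherInfo from DISPATCHER_INFOS):

from torchvision.prototype import datapoints  # assumed import path

# previously spelled info.sample_inputs(features.Image)
for args_kwargs in info.sample_inputs(datapoints.Image):
    (image_datapoint, *other_args), kwargs = args_kwargs.load()
    info.dispatcher(image_datapoint, *other_args, **kwargs)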
@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
self.pil_kernel_info = pil_kernel_info
kernel_infos = {}
for feature_type, kernel in self.kernels.items():
for datapoint_type, kernel in self.kernels.items():
kernel_info = self._KERNEL_INFO_MAP.get(kernel)
if not kernel_info:
raise pytest.UsageError(
f"Can't register {kernel.__name__} for type {feature_type} since there is no `KernelInfo` for it. "
f"Can't register {kernel.__name__} for type {datapoint_type} since there is no `KernelInfo` for it. "
f"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
)
kernel_infos[feature_type] = kernel_info
kernel_infos[datapoint_type] = kernel_info
self.kernel_infos = kernel_infos
def sample_inputs(self, *feature_types, filter_metadata=True):
for feature_type in feature_types or self.kernel_infos.keys():
kernel_info = self.kernel_infos.get(feature_type)
def sample_inputs(self, *datapoint_types, filter_metadata=True):
for datapoint_type in datapoint_types or self.kernel_infos.keys():
kernel_info = self.kernel_infos.get(datapoint_type)
if not kernel_info:
raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
@@ -66,7 +66,7 @@ class DispatcherInfo(InfoBase):
yield from sample_inputs
else:
for args_kwargs in sample_inputs:
for attribute in feature_type.__annotations__.keys():
for attribute in datapoint_type.__annotations__.keys():
if attribute in args_kwargs.kwargs:
del args_kwargs.kwargs[attribute]
@@ -107,9 +107,9 @@ def xfail_jit_list_of_ints(name, *, reason=None):
)
skip_dispatch_feature = TestMark(
("TestDispatchers", "test_dispatch_feature"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."),
skip_dispatch_datapoint = TestMark(
("TestDispatchers", "test_dispatch_datapoint"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
)
@@ -352,7 +352,7 @@ DISPATCHER_INFOS = [
},
pil_kernel_info=PILKernelInfo(F.erase_image_pil),
test_marks=[
skip_dispatch_feature,
skip_dispatch_datapoint,
],
),
DispatcherInfo(
@@ -404,7 +404,7 @@ DISPATCHER_INFOS = [
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
test_marks=[
xfail_jit_python_scalar_arg("size"),
skip_dispatch_feature,
skip_dispatch_datapoint,
],
),
DispatcherInfo(
@@ -415,7 +415,7 @@ DISPATCHER_INFOS = [
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
skip_dispatch_feature,
skip_dispatch_datapoint,
],
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
),
@@ -437,7 +437,7 @@ DISPATCHER_INFOS = [
datapoints.Video: F.convert_dtype_video,
},
test_marks=[
skip_dispatch_feature,
skip_dispatch_datapoint,
],
),
DispatcherInfo(
@@ -446,7 +446,7 @@ DISPATCHER_INFOS = [
datapoints.Video: F.uniform_temporal_subsample_video,
},
test_marks=[
skip_dispatch_feature,
skip_dispatch_datapoint,
],
),
]
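
For reference, a full DISPATCHER_INFOS entry after the rename looks roughly like the following. This is a sketch assembled from the hunks above; the positional dispatcher argument, the assignment, and the import paths are assumptions, while the kernel mapping and the renamed mark are taken from the diff:

from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

example_info = DispatcherInfo(
    F.uniform_temporal_subsample,
    kernels={
        datapoints.Video: F.uniform_temporal_subsample_video,
    },
    test_marks=[
        # dispatchers without a corresponding Datapoint method opt out of the
        # arbitrary-dispatch test via the renamed mark
        skip_dispatch_datapoint,
    ],
)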
@@ -28,7 +28,7 @@ def test_to_wrapping():
assert label_to.categories is label.categories
def test_to_feature_reference():
def test_to_datapoint_reference():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
@@ -285,7 +285,7 @@ class TestRandomHorizontalFlip:
assert_equal(expected, pil_to_tensor(actual))
def test_features_image(self, p):
def test_datapoints_image(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)
@@ -293,7 +293,7 @@ class TestRandomHorizontalFlip:
assert_equal(datapoints.Image(expected), actual)
def test_features_mask(self, p):
def test_datapoints_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)
@@ -301,7 +301,7 @@ class TestRandomHorizontalFlip:
assert_equal(datapoints.Mask(expected), actual)
def test_features_bounding_box(self, p):
def test_datapoints_bounding_box(self, p):
input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
transform = transforms.RandomHorizontalFlip(p=p)
@@ -338,7 +338,7 @@ class TestRandomVerticalFlip:
assert_equal(expected, pil_to_tensor(actual))
def test_features_image(self, p):
def test_datapoints_image(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)
@@ -346,7 +346,7 @@ class TestRandomVerticalFlip:
assert_equal(datapoints.Image(expected), actual)
def test_features_mask(self, p):
def test_datapoints_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)
@@ -354,7 +354,7 @@ class TestRandomVerticalFlip:
assert_equal(datapoints.Mask(expected), actual)
def test_features_bounding_box(self, p):
def test_datapoints_bounding_box(self, p):
input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
transform = transforms.RandomVerticalFlip(p=p)
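
To see one of the renamed tests in full: the hunks above cut each body right after the transform is constructed, but the assert lines that survive as context show the overall shape. A sketch of a complete renamed test (the `actual = ...` line is elided by the diff and is an assumed reconstruction):

def test_datapoints_image(self, p):
    input, expected = self.input_expected_image_tensor(p)
    transform = transforms.RandomVerticalFlip(p=p)

    # not shown in the diff; assumed from the surviving assert line
    actual = transform(datapoints.Image(input))

    assert_equal(datapoints.Image(expected), actual)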
@@ -558,15 +558,15 @@ def check_call_consistency(
output_prototype_image = prototype_transform(image)
except Exception as exc:
raise AssertionError(
f"Transforming a feature image with shape {image_repr} failed in the prototype transform with "
f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`features.Image` path in `_transform`."
f"`datapoints.Image` path in `_transform`."
) from exc
assert_close(
output_prototype_image,
output_prototype_tensor,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
**closeness_kwargs,
)
@@ -931,7 +931,7 @@ class TestRefDetTransforms:
yield (tensor_image, target)
feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -939,7 +939,7 @@ class TestRefDetTransforms:
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
yield (feature_image, target)
yield (datapoint_image, target)
@pytest.mark.parametrize(
"t_ref, t, data_kwargs",
@@ -1015,13 +1015,13 @@ class TestRefSegTransforms:
conv_fns.extend([torch.Tensor, lambda x: x])
for conv_fn in conv_fns:
feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
dp = (conv_fn(feature_image), feature_mask)
dp = (conv_fn(datapoint_image), datapoint_mask)
dp_ref = (
to_image_pil(feature_image) if supports_pil else feature_image.as_subclass(torch.Tensor),
to_image_pil(feature_mask),
to_image_pil(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
to_image_pil(datapoint_mask),
)
yield dp, dp_ref
@@ -162,7 +162,7 @@ class TestKernels:
def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device)
feature_type = (
datapoint_type = (
datapoints.Image
if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
else type(batched_input)
@@ -178,10 +178,10 @@ class TestKernels:
# common ground.
datapoints.Mask: 2,
datapoints.Video: 4,
}.get(feature_type)
}.get(datapoint_type)
if data_dims is None:
raise pytest.UsageError(
f"The number of data dimensions cannot be determined for input of type {feature_type.__name__}."
f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
) from None
elif batched_input.ndim <= data_dims:
pytest.skip("Input is not batched.")
@@ -323,8 +323,8 @@ class TestDispatchers:
def test_scripted_smoke(self, info, args_kwargs, device):
dispatcher = script(info.dispatcher)
(image_feature, *other_args), kwargs = args_kwargs.load(device)
image_simple_tensor = torch.Tensor(image_feature)
(image_datapoint, *other_args), kwargs = args_kwargs.load(device)
image_simple_tensor = torch.Tensor(image_datapoint)
dispatcher(image_simple_tensor, *other_args, **kwargs)
@@ -352,8 +352,8 @@ class TestDispatchers:
@image_sample_inputs
def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
(image_feature, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = torch.Tensor(image_feature)
(image_datapoint, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = torch.Tensor(image_datapoint)
kernel_info = info.kernel_infos[datapoints.Image]
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
@@ -367,12 +367,12 @@ class TestDispatchers:
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
)
def test_dispatch_pil(self, info, args_kwargs, spy_on):
(image_feature, *other_args), kwargs = args_kwargs.load()
(image_datapoint, *other_args), kwargs = args_kwargs.load()
if image_feature.ndim > 3:
if image_datapoint.ndim > 3:
pytest.skip("Input is batched")
image_pil = F.to_image_pil(image_feature)
image_pil = F.to_image_pil(image_datapoint)
pil_kernel_info = info.pil_kernel_info
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
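
The dispatch tests in this file all follow the same spy pattern; condensed, it looks roughly like this (the final dispatcher call for the PIL case is cut off by the hunk above and is an assumption here; `spy_on` is the suite's own fixture):

(image_datapoint, *other_args), kwargs = args_kwargs.load()
image_pil = F.to_image_pil(image_datapoint)

pil_kernel_info = info.pil_kernel_info
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)

# assumed continuation: dispatch the PIL image and check that the PIL kernel was hit
info.dispatcher(image_pil, *other_args, **kwargs)
spy.assert_called_once()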
@@ -385,37 +385,39 @@ class TestDispatchers:
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_dispatch_feature(self, info, args_kwargs, spy_on):
(feature, *other_args), kwargs = args_kwargs.load()
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
(datapoint, *other_args), kwargs = args_kwargs.load()
method_name = info.id
method = getattr(feature, method_name)
feature_type = type(feature)
spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{method_name}")
method = getattr(datapoint, method_name)
datapoint_type = type(datapoint)
spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
info.dispatcher(feature, *other_args, **kwargs)
info.dispatcher(datapoint, *other_args, **kwargs)
spy.assert_called_once()
@pytest.mark.parametrize(
("dispatcher_info", "feature_type", "kernel_info"),
("dispatcher_info", "datapoint_type", "kernel_info"),
[
pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}")
pytest.param(
dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
)
for dispatcher_info in DISPATCHER_INFOS
for feature_type, kernel_info in dispatcher_info.kernel_infos.items()
for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
],
)
def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info):
def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
kernel_signature = inspect.signature(kernel_info.kernel)
kernel_params = list(kernel_signature.parameters.values())[1:]
# We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be
# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
# explicitly passed to the kernel.
feature_type_metadata = feature_type.__annotations__.keys()
kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata]
datapoint_type_metadata = datapoint_type.__annotations__.keys()
kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -433,26 +435,26 @@ class TestDispatchers:
assert dispatcher_param == kernel_param
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_feature_signatures_consistency(self, info):
def test_dispatcher_datapoint_signatures_consistency(self, info):
try:
feature_method = getattr(datapoints._datapoint.Datapoint, info.id)
datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")
pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")
dispatcher_signature = inspect.signature(info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
feature_signature = inspect.signature(feature_method)
feature_params = list(feature_signature.parameters.values())[1:]
datapoint_signature = inspect.signature(datapoint_method)
datapoint_params = list(datapoint_signature.parameters.values())[1:]
# Because we use `from __future__ import annotations` inside the module where `features._datapoint` is defined,
# the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
# concrete dispatcher annotations.
feature_annotations = get_type_hints(feature_method)
for param in feature_params:
param._annotation = feature_annotations[param.name]
# Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
# defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
# natively concrete dispatcher annotations.
datapoint_annotations = get_type_hints(datapoint_method)
for param in datapoint_params:
param._annotation = datapoint_annotations[param.name]
assert dispatcher_params == feature_params
assert dispatcher_params == datapoint_params
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_unkown_type(self, info):
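
The comment rewritten in the last hunk relies on a small but important mechanism: with `from __future__ import annotations`, the Datapoint methods carry string annotations, and `get_type_hints` is what turns them back into concrete objects before the parameter-by-parameter comparison. A standalone illustration of just that mechanism (the `resize` function below is hypothetical, not torchvision's):

from __future__ import annotations

import inspect
from typing import List, get_type_hints


def resize(size: List[int], antialias: bool = True) -> List[int]:
    # dummy body; only the annotations matter for this illustration
    return size


# with postponed evaluation, the raw annotation is stored as the string "List[int]" ...
assert inspect.signature(resize).parameters["size"].annotation == "List[int]"
# ... while get_type_hints resolves it to a concrete typing object, which is what lets
# the test compare datapoint-method parameters against the dispatcher's signature
assert get_type_hints(resize)["size"] == List[int]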