"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "33d10af28fcfb4d41ab7fb97d84c8ac2317576d5"
Unverified Commit 01d138d8 authored by Philip Meier, committed by GitHub

update naming feature -> datapoint in prototype test suite (#7117)

parent d7e5b6a1
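For context: in torchvision's prototype transforms API, the tensor subclasses formerly called "features" were renamed to "datapoints", and this commit updates the test suite to match. A minimal sketch of the terminology the tests rely on (a sketch only, assuming the prototype import paths of this era; the prototype API is unstable and has since changed):

import torch
from torchvision.prototype import datapoints

# A datapoint (formerly "feature") is a torch.Tensor subclass carrying
# metadata, e.g. Image, Mask, BoundingBox, Label, Video.
image_datapoint = datapoints.Image(torch.rand(3, 32, 32))

# A "simple tensor" is a plain torch.Tensor without a datapoint subclass;
# the tests below convert between the two with as_subclass / torch.Tensor(...).
image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

assert isinstance(image_datapoint, datapoints.Image)
assert type(image_simple_tensor) is torch.Tensor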
@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
         self.pil_kernel_info = pil_kernel_info

         kernel_infos = {}
-        for feature_type, kernel in self.kernels.items():
+        for datapoint_type, kernel in self.kernels.items():
             kernel_info = self._KERNEL_INFO_MAP.get(kernel)
             if not kernel_info:
                 raise pytest.UsageError(
-                    f"Can't register {kernel.__name__} for type {feature_type} since there is no `KernelInfo` for it. "
+                    f"Can't register {kernel.__name__} for type {datapoint_type} since there is no `KernelInfo` for it. "
                     f"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
                 )
-            kernel_infos[feature_type] = kernel_info
+            kernel_infos[datapoint_type] = kernel_info
         self.kernel_infos = kernel_infos

-    def sample_inputs(self, *feature_types, filter_metadata=True):
-        for feature_type in feature_types or self.kernel_infos.keys():
-            kernel_info = self.kernel_infos.get(feature_type)
+    def sample_inputs(self, *datapoint_types, filter_metadata=True):
+        for datapoint_type in datapoint_types or self.kernel_infos.keys():
+            kernel_info = self.kernel_infos.get(datapoint_type)
             if not kernel_info:
                 raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
@@ -66,7 +66,7 @@ class DispatcherInfo(InfoBase):
                 yield from sample_inputs
             else:
                 for args_kwargs in sample_inputs:
-                    for attribute in feature_type.__annotations__.keys():
+                    for attribute in datapoint_type.__annotations__.keys():
                         if attribute in args_kwargs.kwargs:
                             del args_kwargs.kwargs[attribute]
@@ -107,9 +107,9 @@ def xfail_jit_list_of_ints(name, *, reason=None):
 )


-skip_dispatch_feature = TestMark(
-    ("TestDispatchers", "test_dispatch_feature"),
-    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."),
+skip_dispatch_datapoint = TestMark(
+    ("TestDispatchers", "test_dispatch_datapoint"),
+    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
 )
@@ -352,7 +352,7 @@ DISPATCHER_INFOS = [
         },
         pil_kernel_info=PILKernelInfo(F.erase_image_pil),
         test_marks=[
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
     ),
     DispatcherInfo(
@@ -404,7 +404,7 @@ DISPATCHER_INFOS = [
         pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
     ),
     DispatcherInfo(
@@ -415,7 +415,7 @@ DISPATCHER_INFOS = [
         },
         test_marks=[
             xfail_jit_python_scalar_arg("size"),
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
         pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
     ),
@@ -437,7 +437,7 @@ DISPATCHER_INFOS = [
             datapoints.Video: F.convert_dtype_video,
         },
         test_marks=[
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
     ),
     DispatcherInfo(
@@ -446,7 +446,7 @@ DISPATCHER_INFOS = [
             datapoints.Video: F.uniform_temporal_subsample_video,
         },
         test_marks=[
-            skip_dispatch_feature,
+            skip_dispatch_datapoint,
         ],
     ),
 ]
@@ -28,7 +28,7 @@ def test_to_wrapping():
     assert label_to.categories is label.categories


-def test_to_feature_reference():
+def test_to_datapoint_reference():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
     label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
...
@@ -285,7 +285,7 @@ class TestRandomHorizontalFlip:
         assert_equal(expected, pil_to_tensor(actual))

-    def test_features_image(self, p):
+    def test_datapoints_image(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomHorizontalFlip(p=p)
@@ -293,7 +293,7 @@ class TestRandomHorizontalFlip:
         assert_equal(datapoints.Image(expected), actual)

-    def test_features_mask(self, p):
+    def test_datapoints_mask(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomHorizontalFlip(p=p)
@@ -301,7 +301,7 @@ class TestRandomHorizontalFlip:
         assert_equal(datapoints.Mask(expected), actual)

-    def test_features_bounding_box(self, p):
+    def test_datapoints_bounding_box(self, p):
         input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
         transform = transforms.RandomHorizontalFlip(p=p)
@@ -338,7 +338,7 @@ class TestRandomVerticalFlip:
         assert_equal(expected, pil_to_tensor(actual))

-    def test_features_image(self, p):
+    def test_datapoints_image(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomVerticalFlip(p=p)
@@ -346,7 +346,7 @@ class TestRandomVerticalFlip:
         assert_equal(datapoints.Image(expected), actual)

-    def test_features_mask(self, p):
+    def test_datapoints_mask(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomVerticalFlip(p=p)
@@ -354,7 +354,7 @@ class TestRandomVerticalFlip:
         assert_equal(datapoints.Mask(expected), actual)

-    def test_features_bounding_box(self, p):
+    def test_datapoints_bounding_box(self, p):
         input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
         transform = transforms.RandomVerticalFlip(p=p)
...
@@ -558,15 +558,15 @@ def check_call_consistency(
             output_prototype_image = prototype_transform(image)
         except Exception as exc:
             raise AssertionError(
-                f"Transforming a feature image with shape {image_repr} failed in the prototype transform with "
+                f"Transforming an image datapoint with shape {image_repr} failed in the prototype transform with "
                 f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-                f"`features.Image` path in `_transform`."
+                f"`datapoints.Image` path in `_transform`."
             ) from exc

         assert_close(
             output_prototype_image,
             output_prototype_tensor,
-            msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
+            msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
             **closeness_kwargs,
         )
@@ -931,7 +931,7 @@ class TestRefDetTransforms:
             yield (tensor_image, target)

-        feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
+        datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
         target = {
             "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -939,7 +939,7 @@ class TestRefDetTransforms:
         if with_mask:
             target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

-        yield (feature_image, target)
+        yield (datapoint_image, target)

     @pytest.mark.parametrize(
         "t_ref, t, data_kwargs",
@@ -1015,13 +1015,13 @@ class TestRefSegTransforms:
             conv_fns.extend([torch.Tensor, lambda x: x])

         for conv_fn in conv_fns:
-            feature_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
-            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
+            datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
+            datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

-            dp = (conv_fn(feature_image), feature_mask)
+            dp = (conv_fn(datapoint_image), datapoint_mask)
             dp_ref = (
-                to_image_pil(feature_image) if supports_pil else feature_image.as_subclass(torch.Tensor),
-                to_image_pil(feature_mask),
+                to_image_pil(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
+                to_image_pil(datapoint_mask),
             )

             yield dp, dp_ref
...
@@ -162,7 +162,7 @@ class TestKernels:
     def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         (batched_input, *other_args), kwargs = args_kwargs.load(device)

-        feature_type = (
+        datapoint_type = (
             datapoints.Image
             if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
             else type(batched_input)
@@ -178,10 +178,10 @@ class TestKernels:
             # common ground.
             datapoints.Mask: 2,
             datapoints.Video: 4,
-        }.get(feature_type)
+        }.get(datapoint_type)
         if data_dims is None:
             raise pytest.UsageError(
-                f"The number of data dimensions cannot be determined for input of type {feature_type.__name__}."
+                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
             ) from None
         elif batched_input.ndim <= data_dims:
             pytest.skip("Input is not batched.")
@@ -323,8 +323,8 @@ class TestDispatchers:
     def test_scripted_smoke(self, info, args_kwargs, device):
         dispatcher = script(info.dispatcher)

-        (image_feature, *other_args), kwargs = args_kwargs.load(device)
-        image_simple_tensor = torch.Tensor(image_feature)
+        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
+        image_simple_tensor = torch.Tensor(image_datapoint)

         dispatcher(image_simple_tensor, *other_args, **kwargs)
@@ -352,8 +352,8 @@ class TestDispatchers:
     @image_sample_inputs
     def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
-        (image_feature, *other_args), kwargs = args_kwargs.load()
-        image_simple_tensor = torch.Tensor(image_feature)
+        (image_datapoint, *other_args), kwargs = args_kwargs.load()
+        image_simple_tensor = torch.Tensor(image_datapoint)

         kernel_info = info.kernel_infos[datapoints.Image]
         spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)
@@ -367,12 +367,12 @@ class TestDispatchers:
         args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
     )
     def test_dispatch_pil(self, info, args_kwargs, spy_on):
-        (image_feature, *other_args), kwargs = args_kwargs.load()
+        (image_datapoint, *other_args), kwargs = args_kwargs.load()

-        if image_feature.ndim > 3:
+        if image_datapoint.ndim > 3:
             pytest.skip("Input is batched")

-        image_pil = F.to_image_pil(image_feature)
+        image_pil = F.to_image_pil(image_datapoint)

         pil_kernel_info = info.pil_kernel_info
         spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
@@ -385,37 +385,39 @@ class TestDispatchers:
         DISPATCHER_INFOS,
         args_kwargs_fn=lambda info: info.sample_inputs(),
     )
-    def test_dispatch_feature(self, info, args_kwargs, spy_on):
-        (feature, *other_args), kwargs = args_kwargs.load()
+    def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
+        (datapoint, *other_args), kwargs = args_kwargs.load()

         method_name = info.id
-        method = getattr(feature, method_name)
-        feature_type = type(feature)
-        spy = spy_on(method, module=feature_type.__module__, name=f"{feature_type.__name__}.{method_name}")
+        method = getattr(datapoint, method_name)
+        datapoint_type = type(datapoint)
+        spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")

-        info.dispatcher(feature, *other_args, **kwargs)
+        info.dispatcher(datapoint, *other_args, **kwargs)

         spy.assert_called_once()

     @pytest.mark.parametrize(
-        ("dispatcher_info", "feature_type", "kernel_info"),
+        ("dispatcher_info", "datapoint_type", "kernel_info"),
         [
-            pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}")
+            pytest.param(
+                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
+            )
             for dispatcher_info in DISPATCHER_INFOS
-            for feature_type, kernel_info in dispatcher_info.kernel_infos.items()
+            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
         ],
     )
-    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info):
+    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
         dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
         dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

         kernel_signature = inspect.signature(kernel_info.kernel)
         kernel_params = list(kernel_signature.parameters.values())[1:]

-        # We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be
+        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
         # explicit passed to the kernel.
-        feature_type_metadata = feature_type.__annotations__.keys()
-        kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata]
+        datapoint_type_metadata = datapoint_type.__annotations__.keys()
+        kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]

         dispatcher_params = iter(dispatcher_params)
         for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -433,26 +435,26 @@ class TestDispatchers:
             assert dispatcher_param == kernel_param

     @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
-    def test_dispatcher_feature_signatures_consistency(self, info):
+    def test_dispatcher_datapoint_signatures_consistency(self, info):
         try:
-            feature_method = getattr(datapoints._datapoint.Datapoint, info.id)
+            datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
         except AttributeError:
-            pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")
+            pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")

         dispatcher_signature = inspect.signature(info.dispatcher)
         dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

-        feature_signature = inspect.signature(feature_method)
-        feature_params = list(feature_signature.parameters.values())[1:]
+        datapoint_signature = inspect.signature(datapoint_method)
+        datapoint_params = list(datapoint_signature.parameters.values())[1:]

-        # Because we use `from __future__ import annotations` inside the module where `features._datapoint` is defined,
-        # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
-        # concrete dispatcher annotations.
-        feature_annotations = get_type_hints(feature_method)
-        for param in feature_params:
-            param._annotation = feature_annotations[param.name]
+        # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
+        # defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
+        # natively concrete dispatcher annotations.
+        datapoint_annotations = get_type_hints(datapoint_method)
+        for param in datapoint_params:
+            param._annotation = datapoint_annotations[param.name]

-        assert dispatcher_params == feature_params
+        assert dispatcher_params == datapoint_params

     @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
     def test_unkown_type(self, info):
...
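The renamed test_dispatch_datapoint covers what the skip message above calls "arbitrary datapoint dispatch": a dispatcher receives a datapoint and delegates to the method of the same name on its type. A rough illustration under the same assumptions as the sketch above (horizontal_flip is merely an example dispatcher from the prototype functional API of that era, not part of this commit):

import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

image = datapoints.Image(torch.rand(3, 8, 8))

# For a datapoint input, the dispatcher resolves type(image).horizontal_flip
# (the call path spied on in test_dispatch_datapoint); for a simple tensor it
# calls the image kernel directly (covered by test_dispatch_simple_tensor).
flipped = F.horizontal_flip(image)
assert isinstance(flipped, datapoints.Image)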