Unverified commit f9d1883e authored by Philip Meier, committed by GitHub

Fix some annotations in transforms v2 for JIT v1 compatibility (#7252)

parent f6b5b82e
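In short: `torch.jit.script` type-checks calls against the Python annotations, so the pad-related kernels below move from `padding: Union[int, List[int]]` and `fill: FillTypeJIT` to `padding: List[int]` and `fill: Optional[Union[int, float, List[float]]]`, while `FillTypeJIT` itself is narrowed to `Optional[List[float]]`. A minimal sketch of the resulting annotation style, using a made-up `pad_stub` rather than the actual torchvision kernel:

from typing import List, Optional, Union

import torch


def pad_stub(
    image: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    # Placeholder body; the real kernels dispatch to torch.nn.functional.pad.
    return image


scripted = torch.jit.script(pad_stub)
img = torch.zeros(3, 8, 8)

# Every call pattern that scripted v1 code may use has to type-check against the annotations:
scripted(img, [1, 1, 1, 1])
scripted(img, [1, 1, 1, 1], fill=1)
scripted(img, [1, 1, 1, 1], fill=0.5)
scripted(img, [1, 1, 1, 1], fill=[0.0, 0.5, 1.0])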
@@ -3,7 +3,7 @@ import collections.abc
import pytest
import torchvision.prototype.transforms.functional as F
from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
from torchvision.prototype import datapoints
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -96,25 +96,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
)
def xfail_jit_tuple_instead_of_list(name, *, reason=None):
return xfail_jit(
reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
)
def is_list_of_ints(args_kwargs):
fill = args_kwargs.kwargs.get("fill")
return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
def xfail_jit_list_of_ints(name, *, reason=None):
return xfail_jit(
reason or f"Passing a list of integers for `{name}` is not supported when scripting",
condition=is_list_of_ints,
)
skip_dispatch_datapoint = TestMark(
("TestDispatchers", "test_dispatch_datapoint"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
@@ -130,6 +111,13 @@ multi_crop_skips = [
multi_crop_skips.append(skip_dispatch_datapoint)
def xfails_pil(reason, *, condition=None):
return [
TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
for test_name in ["test_dispatch_pil", "test_pil_output_type"]
]
def fill_sequence_needs_broadcast(args_kwargs):
(image_loader, *_), kwargs = args_kwargs
try:
@@ -143,11 +131,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
return image_loader.num_channels > 1
xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
("TestDispatchers", "test_dispatch_pil"),
pytest.mark.xfail(
reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
),
xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
"PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
condition=fill_sequence_needs_broadcast,
)
@@ -186,11 +171,9 @@ DISPATCHER_INFOS = [
},
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
*xfails_pil_if_fill_sequence_needs_broadcast,
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("fill"),
],
),
DispatcherInfo(
@@ -213,9 +196,8 @@ DISPATCHER_INFOS = [
},
pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("fill"),
*xfails_pil_if_fill_sequence_needs_broadcast,
],
),
DispatcherInfo(
@@ -248,21 +230,16 @@ DISPATCHER_INFOS = [
},
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[
TestMark(
("TestDispatchers", "test_dispatch_pil"),
pytest.mark.xfail(
reason=(
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
"`padding_mode='constant'`, if the number of color channels is larger."
)
*xfails_pil(
reason=(
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
"`padding_mode='constant'`, if the number of color channels is larger."
),
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
),
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
xfail_jit_python_scalar_arg("padding"),
],
),
DispatcherInfo(
@@ -275,7 +252,8 @@ DISPATCHER_INFOS = [
},
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
*xfails_pil_if_fill_sequence_needs_broadcast,
xfail_jit_python_scalar_arg("fill"),
],
),
DispatcherInfo(
@@ -287,6 +265,7 @@ DISPATCHER_INFOS = [
datapoints.Mask: F.elastic_mask,
},
pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
DispatcherInfo(
F.center_crop,
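The new `xfail_jit_python_scalar_arg("fill")` marks above line up with `FillTypeJIT` being narrowed to `Optional[List[float]]` in the datapoints hunks further down: under `torch.jit.script`, a plain Python scalar no longer binds to a parameter annotated that way. A small illustration under that assumption; `takes_list_fill` is a hypothetical stand-in, not a torchvision dispatcher:

from typing import List, Optional

import torch


def takes_list_fill(image: torch.Tensor, fill: Optional[List[float]] = None) -> torch.Tensor:
    # Stand-in for a dispatcher whose `fill` is annotated with the new FillTypeJIT.
    return image


scripted = torch.jit.script(takes_list_fill)
img = torch.zeros(3, 8, 8)

scripted(img, fill=[0.5])  # a list of floats matches Optional[List[float]]

try:
    scripted(img, fill=1)  # a Python scalar does not, hence the JIT xfail marks
except RuntimeError as exc:
    print("rejected by TorchScript:", type(exc).__name__)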
@@ -153,26 +153,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
)
def xfail_jit_tuple_instead_of_list(name, *, reason=None):
reason = reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting"
return xfail_jit(
reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
)
def is_list_of_ints(args_kwargs):
fill = args_kwargs.kwargs.get("fill")
return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
def xfail_jit_list_of_ints(name, *, reason=None):
return xfail_jit(
reason or f"Passing a list of integers for `{name}` is not supported when scripting",
condition=is_list_of_ints,
)
KERNEL_INFOS = []
@@ -450,21 +430,21 @@ _DIVERSE_AFFINE_PARAMS = [
]
def get_fills(*, num_channels, dtype, vector=True):
def get_fills(*, num_channels, dtype):
yield None
max_value = get_max_value(dtype)
# This intentionally gives us a float and an int scalar fill value
yield max_value / 2
yield max_value
int_value = get_max_value(dtype)
float_value = int_value / 2
yield int_value
yield float_value
if not vector:
return
for vector_type in [list, tuple]:
yield vector_type([int_value])
yield vector_type([float_value])
if dtype.is_floating_point:
yield [0.1 + c / 10 for c in range(num_channels)]
else:
yield [12.0 + c for c in range(num_channels)]
if num_channels > 1:
yield vector_type(float_value * c / 10 for c in range(num_channels))
yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))
def float32_vs_uint8_fill_adapter(other_args, kwargs):
@@ -644,9 +624,7 @@ KERNEL_INFOS.extend(
closeness_kwargs=pil_reference_pixel_difference(10, mae=True),
test_marks=[
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("fill"),
],
),
KernelInfo(
@@ -873,9 +851,7 @@ KERNEL_INFOS.extend(
float32_vs_uint8=True,
closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("fill"),
],
),
KernelInfo(
@@ -1122,12 +1098,14 @@ def reference_inputs_pad_image_tensor():
for image_loader, params in itertools.product(
make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
for fill in get_fills(
num_channels=image_loader.num_channels,
dtype=image_loader.dtype,
vector=params["padding_mode"] == "constant",
):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
if isinstance(fill, (list, tuple)):
continue
yield ArgsKwargs(image_loader, fill=fill, **params)
@@ -1195,6 +1173,16 @@ def reference_inputs_pad_bounding_box():
)
def pad_xfail_jit_fill_condition(args_kwargs):
fill = args_kwargs.kwargs.get("fill")
if not isinstance(fill, (list, tuple)):
return False
elif isinstance(fill, tuple):
return True
else: # isinstance(fill, list):
return all(isinstance(f, int) for f in fill)
KERNEL_INFOS.extend(
[
KernelInfo(
@@ -1205,10 +1193,10 @@ KERNEL_INFOS.extend(
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
test_marks=[
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
xfail_jit_python_scalar_arg("padding"),
xfail_jit(
"F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
),
],
),
KernelInfo(
@@ -1217,7 +1205,7 @@ KERNEL_INFOS.extend(
reference_fn=reference_pad_bounding_box,
reference_inputs_fn=reference_inputs_pad_bounding_box,
test_marks=[
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_python_scalar_arg("padding"),
],
),
KernelInfo(
@@ -1261,8 +1249,11 @@ def reference_inputs_perspective_image_tensor():
F.InterpolationMode.BILINEAR,
],
):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
# FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
if isinstance(fill, (list, tuple)):
continue
yield ArgsKwargs(
image_loader,
startpoints=None,
@@ -1327,6 +1318,7 @@ KERNEL_INFOS.extend(
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
},
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
KernelInfo(
F.perspective_bounding_box,
@@ -1418,6 +1410,7 @@ KERNEL_INFOS.extend(
**float32_vs_uint8_pixel_difference(6, mae=True),
**cuda_vs_cpu_pixel_difference(),
},
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
KernelInfo(
F.elastic_bounding_box,
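The `pad_xfail_jit_fill_condition` helper above encodes which `fill` values are expected to fail once `F.pad` is scripted: scalars and lists of floats pass, while tuples and lists of ints do not ("F.pad only supports vector fills for list of floats"). A self-contained restatement of that predicate with a few sanity checks; `needs_jit_xfail` is a local mirror, not an import from the test suite:

from typing import Any


def needs_jit_xfail(fill: Any) -> bool:
    # Mirrors pad_xfail_jit_fill_condition from the hunk above, applied directly to `fill`.
    if not isinstance(fill, (list, tuple)):
        return False
    if isinstance(fill, tuple):
        return True
    # A list only fails when every element is an int, i.e. it is not a List[float].
    return all(isinstance(f, int) for f in fill)


assert not needs_jit_xfail(None)             # no fill, nothing to check
assert not needs_jit_xfail(1)                # scalar fills are covered by the Union annotation
assert not needs_jit_xfail([0.0, 0.5, 1.0])  # a list of floats is the supported vector form
assert needs_jit_xfail((0.0, 0.5, 1.0))      # a tuple is not a List[float] under TorchScript
assert needs_jit_xfail([1, 2, 3])            # neither is a list of ints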
@@ -118,7 +118,7 @@ class BoundingBox(Datapoint):
def pad(
self,
padding: Union[int, Sequence[int]],
fill: FillTypeJIT = None,
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> BoundingBox:
output, spatial_size = self._F.pad_bounding_box(
@@ -12,7 +12,7 @@ from torchvision.transforms import InterpolationMode
D = TypeVar("D", bound="Datapoint")
FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Union[int, float, List[float], None]
FillTypeJIT = Optional[List[float]]
class Datapoint(torch.Tensor):
@@ -169,8 +169,8 @@ class Datapoint(torch.Tensor):
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Datapoint:
return self
@@ -103,8 +103,8 @@ class Image(Datapoint):
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
@@ -83,8 +83,8 @@ class Mask(Datapoint):
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Mask:
output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
@@ -102,8 +102,8 @@ class Video(Datapoint):
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Video:
output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
@@ -270,7 +270,7 @@ class Pad(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
class RandomZoomOut(_RandomApplyTransform):
@@ -60,10 +60,9 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
if fill is None:
return fill
# This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)]
return fill
return fill # type: ignore[return-value]
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]:
@@ -432,7 +432,7 @@ def _apply_grid_transform(
if fill is not None:
float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
mask = mask.expand_as(float_img)
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type]
fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
if mode == "nearest":
bool_mask = mask < 0.5
@@ -968,8 +968,8 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
def pad_image_tensor(
image: torch.Tensor,
padding: Union[int, List[int]],
fill: datapoints.FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> torch.Tensor:
# Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
@@ -1069,14 +1069,14 @@ pad_image_pil = _FP.pad
def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
fill: datapoints.FillTypeJIT = None,
) -> torch.Tensor:
if fill is None:
fill = 0
if isinstance(fill, list):
if isinstance(fill, (tuple, list)):
raise ValueError("Non-scalar fill value is not supported")
if mask.ndim < 3:
@@ -1097,7 +1097,7 @@ def pad_bounding_box(
bounding_box: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
padding: Union[int, List[int]],
padding: List[int],
padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if padding_mode not in ["constant"]:
@@ -1122,8 +1122,8 @@ def pad_video(
def pad_video(
video: torch.Tensor,
padding: Union[int, List[int]],
fill: datapoints.FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> torch.Tensor:
return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)
@@ -1131,8 +1131,8 @@ def pad(
def pad(
inpt: datapoints.InputTypeJIT,
padding: Union[int, List[int]],
fill: datapoints.FillTypeJIT = None,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
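Tying the hunks together: user-facing `fill` values (`FillType`) may still be scalars or sequences, but `_convert_fill_arg` normalizes any sequence to a `List[float]` before it reaches the JIT-facing kernels. That is why narrowing `FillTypeJIT` to `Optional[List[float]]` is workable, and why the final `return fill`, which may still be a scalar, now carries a `# type: ignore[return-value]`. A simplified sketch of that normalization, with a hypothetical `convert_fill` standing in for the library helper:

from typing import List, Optional, Sequence, Union

FillType = Union[int, float, Sequence[int], Sequence[float], None]


def convert_fill(fill: FillType) -> Optional[Union[int, float, List[float]]]:
    if fill is None or isinstance(fill, (int, float)):
        # Scalars pass through untouched; this is the case the ignore comment covers.
        return fill
    # Sequence -> List[float], the cast done "to please mypy and torch.jit.script".
    return [float(v) for v in fill]


assert convert_fill(None) is None
assert convert_fill(3) == 3
assert convert_fill((1, 2, 3)) == [1.0, 2.0, 3.0]
assert convert_fill([0.1, 0.2]) == [0.1, 0.2]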