Unverified Commit f9d1883e authored by Philip Meier, committed by GitHub

Fix some annotations in transforms v2 for JIT v1 compatibility (#7252)

parent f6b5b82e
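
The core incompatibility: eager Python ignores annotations, but torch.jit.script enforces them at call time, so `fill` annotations that don't match what the scripted v1 kernels accept break JIT compatibility. A minimal sketch of the failure mode, assuming only `torch` (`fill_to_tensor` is illustrative, not part of this patch):

    import torch
    from typing import List, Optional

    # Illustrative stand-in for a kernel that annotates `fill` as an optional vector.
    def fill_to_tensor(fill: Optional[List[float]] = None) -> torch.Tensor:
        if fill is None:
            fill = [0.0]
        return torch.tensor(fill)

    scripted = torch.jit.script(fill_to_tensor)

    fill_to_tensor(1.0)   # eager call: the annotation is not enforced
    scripted([0.5, 0.5])  # scripted call matching List[float]: fine
    # scripted(1.0)       # RuntimeError: scripted calls enforce the annotation

This is the behavior the `xfail_jit_python_scalar_arg("fill")` marks below encode.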
@@ -3,7 +3,7 @@ import collections.abc
 import pytest
 import torchvision.prototype.transforms.functional as F
 from prototype_common_utils import InfoBase, TestMark
-from prototype_transforms_kernel_infos import KERNEL_INFOS
+from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
 from torchvision.prototype import datapoints

 __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -96,25 +96,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
     )

-def xfail_jit_tuple_instead_of_list(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
-        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
-    )
-
-def is_list_of_ints(args_kwargs):
-    fill = args_kwargs.kwargs.get("fill")
-    return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
-
-def xfail_jit_list_of_ints(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a list of integers for `{name}` is not supported when scripting",
-        condition=is_list_of_ints,
-    )
-
 skip_dispatch_datapoint = TestMark(
     ("TestDispatchers", "test_dispatch_datapoint"),
     pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
@@ -130,6 +111,13 @@ multi_crop_skips = [
 multi_crop_skips.append(skip_dispatch_datapoint)

+def xfails_pil(reason, *, condition=None):
+    return [
+        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
+        for test_name in ["test_dispatch_pil", "test_pil_output_type"]
+    ]
+
 def fill_sequence_needs_broadcast(args_kwargs):
     (image_loader, *_), kwargs = args_kwargs
     try:
@@ -143,11 +131,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
     return image_loader.num_channels > 1

-xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
-    ("TestDispatchers", "test_dispatch_pil"),
-    pytest.mark.xfail(
-        reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
-    ),
+xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
+    "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
     condition=fill_sequence_needs_broadcast,
 )
@@ -186,11 +171,9 @@ DISPATCHER_INFOS = [
         },
         pil_kernel_info=PILKernelInfo(F.affine_image_pil),
         test_marks=[
-            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            *xfails_pil_if_fill_sequence_needs_broadcast,
             xfail_jit_python_scalar_arg("shear"),
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     DispatcherInfo(
@@ -213,9 +196,8 @@ DISPATCHER_INFOS = [
         },
         pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
         test_marks=[
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
+            *xfails_pil_if_fill_sequence_needs_broadcast,
         ],
     ),
     DispatcherInfo(
@@ -248,21 +230,16 @@ DISPATCHER_INFOS = [
         },
         pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
         test_marks=[
-            TestMark(
-                ("TestDispatchers", "test_dispatch_pil"),
-                pytest.mark.xfail(
-                    reason=(
-                        "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
-                        "`padding_mode='constant'`, if the number of color channels is larger."
-                    )
-                ),
+            *xfails_pil(
+                reason=(
+                    "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
+                    "`padding_mode='constant'`, if the number of color channels is larger."
+                ),
                 condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
                 and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
             ),
-            xfail_jit_tuple_instead_of_list("padding"),
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
+            xfail_jit_python_scalar_arg("padding"),
         ],
     ),
     DispatcherInfo(
@@ -275,7 +252,8 @@ DISPATCHER_INFOS = [
         },
         pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
         test_marks=[
-            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            *xfails_pil_if_fill_sequence_needs_broadcast,
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     DispatcherInfo(
@@ -287,6 +265,7 @@ DISPATCHER_INFOS = [
             datapoints.Mask: F.elastic_mask,
         },
         pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
+        test_marks=[xfail_jit_python_scalar_arg("fill")],
     ),
     DispatcherInfo(
         F.center_crop,
......
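
Note the `*` splat at the use sites above: unlike the old single `TestMark`, the new `xfails_pil` helper returns a list with one mark per PIL dispatcher test. Roughly, given the definition above:

    marks = xfails_pil("PIL kernel doesn't support this input")
    # -> [TestMark(("TestDispatchers", "test_dispatch_pil"), ...),
    #     TestMark(("TestDispatchers", "test_pil_output_type"), ...)]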
@@ -153,26 +153,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
     )

-def xfail_jit_tuple_instead_of_list(name, *, reason=None):
-    reason = reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting"
-    return xfail_jit(
-        reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
-        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
-    )
-
-def is_list_of_ints(args_kwargs):
-    fill = args_kwargs.kwargs.get("fill")
-    return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
-
-def xfail_jit_list_of_ints(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a list of integers for `{name}` is not supported when scripting",
-        condition=is_list_of_ints,
-    )
-
 KERNEL_INFOS = []
@@ -450,21 +430,21 @@ _DIVERSE_AFFINE_PARAMS = [
 ]

-def get_fills(*, num_channels, dtype, vector=True):
+def get_fills(*, num_channels, dtype):
     yield None

-    max_value = get_max_value(dtype)
-    # This intentionally gives us a float and an int scalar fill value
-    yield max_value / 2
-    yield max_value
+    int_value = get_max_value(dtype)
+    float_value = int_value / 2
+    yield int_value
+    yield float_value

-    if not vector:
-        return
-
-    if dtype.is_floating_point:
-        yield [0.1 + c / 10 for c in range(num_channels)]
-    else:
-        yield [12.0 + c for c in range(num_channels)]
+    for vector_type in [list, tuple]:
+        yield vector_type([int_value])
+        yield vector_type([float_value])
+
+        if num_channels > 1:
+            yield vector_type(float_value * c / 10 for c in range(num_channels))
+            yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))

 def float32_vs_uint8_fill_adapter(other_args, kwargs):
@@ -644,9 +624,7 @@ KERNEL_INFOS.extend(
             closeness_kwargs=pil_reference_pixel_difference(10, mae=True),
             test_marks=[
                 xfail_jit_python_scalar_arg("shear"),
-                xfail_jit_tuple_instead_of_list("fill"),
-                # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-                xfail_jit_list_of_ints("fill"),
+                xfail_jit_python_scalar_arg("fill"),
             ],
         ),
         KernelInfo(
@@ -873,9 +851,7 @@ KERNEL_INFOS.extend(
             float32_vs_uint8=True,
             closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
             test_marks=[
-                xfail_jit_tuple_instead_of_list("fill"),
-                # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-                xfail_jit_list_of_ints("fill"),
+                xfail_jit_python_scalar_arg("fill"),
             ],
         ),
         KernelInfo(
@@ -1122,12 +1098,14 @@ def reference_inputs_pad_image_tensor():
     for image_loader, params in itertools.product(
         make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
     ):
-        # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
         for fill in get_fills(
             num_channels=image_loader.num_channels,
             dtype=image_loader.dtype,
-            vector=params["padding_mode"] == "constant",
         ):
+            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
+            if isinstance(fill, (list, tuple)):
+                continue
             yield ArgsKwargs(image_loader, fill=fill, **params)
@@ -1195,6 +1173,16 @@ def reference_inputs_pad_bounding_box():
     )

+def pad_xfail_jit_fill_condition(args_kwargs):
+    fill = args_kwargs.kwargs.get("fill")
+    if not isinstance(fill, (list, tuple)):
+        return False
+    elif isinstance(fill, tuple):
+        return True
+    else:  # isinstance(fill, list):
+        return all(isinstance(f, int) for f in fill)
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
@@ -1205,10 +1193,10 @@ KERNEL_INFOS.extend(
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
             closeness_kwargs=float32_vs_uint8_pixel_difference(),
             test_marks=[
-                xfail_jit_tuple_instead_of_list("padding"),
-                xfail_jit_tuple_instead_of_list("fill"),
-                # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-                xfail_jit_list_of_ints("fill"),
+                xfail_jit_python_scalar_arg("padding"),
+                xfail_jit(
+                    "F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
+                ),
             ],
         ),
         KernelInfo(
@@ -1217,7 +1205,7 @@ KERNEL_INFOS.extend(
             reference_fn=reference_pad_bounding_box,
             reference_inputs_fn=reference_inputs_pad_bounding_box,
             test_marks=[
-                xfail_jit_tuple_instead_of_list("padding"),
+                xfail_jit_python_scalar_arg("padding"),
             ],
         ),
         KernelInfo(
@@ -1261,8 +1249,11 @@ def reference_inputs_perspective_image_tensor():
             F.InterpolationMode.BILINEAR,
         ],
     ):
-        # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
         for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
+            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
+            if isinstance(fill, (list, tuple)):
+                continue
             yield ArgsKwargs(
                 image_loader,
                 startpoints=None,
@@ -1327,6 +1318,7 @@ KERNEL_INFOS.extend(
                 **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                 **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
+            test_marks=[xfail_jit_python_scalar_arg("fill")],
         ),
         KernelInfo(
             F.perspective_bounding_box,
@@ -1418,6 +1410,7 @@ KERNEL_INFOS.extend(
                 **float32_vs_uint8_pixel_difference(6, mae=True),
                 **cuda_vs_cpu_pixel_difference(),
             },
+            test_marks=[xfail_jit_python_scalar_arg("fill")],
         ),
         KernelInfo(
             F.elastic_bounding_box,
......
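
To make the `get_fills` rework concrete, this is the sequence the new generator should yield for a 3-channel uint8 image, worked out by hand from the code above (assuming `get_max_value(torch.uint8)` is 255; not captured output):

    list(get_fills(num_channels=3, dtype=torch.uint8))
    # [None,
    #  255, 127.5,            # int and float scalars
    #  [255], [127.5],        # length-1 lists
    #  [0.0, 12.75, 25.5],    # per-channel float list: float_value * c / 10
    #  [255, 0, 255],         # per-channel int list: int_value on even channels
    #  (255,), (127.5,), (0.0, 12.75, 25.5), (255, 0, 255)]  # same values as tuples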
@@ -118,7 +118,7 @@ class BoundingBox(Datapoint):
     def pad(
         self,
         padding: Union[int, Sequence[int]],
-        fill: FillTypeJIT = None,
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> BoundingBox:
         output, spatial_size = self._F.pad_bounding_box(
......
@@ -12,7 +12,7 @@ from torchvision.transforms import InterpolationMode
 D = TypeVar("D", bound="Datapoint")
 FillType = Union[int, float, Sequence[int], Sequence[float], None]
-FillTypeJIT = Union[int, float, List[float], None]
+FillTypeJIT = Optional[List[float]]

 class Datapoint(torch.Tensor):
@@ -169,8 +169,8 @@ class Datapoint(torch.Tensor):
     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Datapoint:
         return self
......
@@ -103,8 +103,8 @@ class Image(Datapoint):
     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Image:
         output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
......
@@ -83,8 +83,8 @@ class Mask(Datapoint):
     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Mask:
         output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
......
@@ -102,8 +102,8 @@ class Video(Datapoint):
     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Video:
         output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
......
@@ -270,7 +270,7 @@ class Pad(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
+        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]

 class RandomZoomOut(_RandomApplyTransform):
......
@@ -60,10 +60,9 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
     if fill is None:
         return fill

-    # This cast does Sequence -> List[float] to please mypy and torch.jit.script
     if not isinstance(fill, (int, float)):
         fill = [float(v) for v in list(fill)]
-    return fill
+    return fill  # type: ignore[return-value]

 def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]:
......
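
Only the mypy suppression changed here; `_convert_fill_arg` still performs the conversion the dropped comment described. A behavior sketch, read off the function rather than captured from a run:

    _convert_fill_arg(None)       # -> None
    _convert_fill_arg(3)          # -> 3 (scalars pass through untouched)
    _convert_fill_arg((1, 2, 3))  # -> [1.0, 2.0, 3.0] (any sequence becomes List[float])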
@@ -432,7 +432,7 @@ def _apply_grid_transform(
     if fill is not None:
         float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
         mask = mask.expand_as(float_img)
-        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]
+        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
         fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
         if mode == "nearest":
             bool_mask = mask < 0.5
@@ -968,8 +968,8 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
 def pad_image_tensor(
     image: torch.Tensor,
-    padding: Union[int, List[int]],
-    fill: datapoints.FillTypeJIT = None,
+    padding: List[int],
+    fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
 ) -> torch.Tensor:
     # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
@@ -1069,14 +1069,14 @@ pad_image_pil = _FP.pad
 def pad_mask(
     mask: torch.Tensor,
-    padding: Union[int, List[int]],
+    padding: List[int],
+    fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
-    fill: datapoints.FillTypeJIT = None,
 ) -> torch.Tensor:
     if fill is None:
         fill = 0

-    if isinstance(fill, list):
+    if isinstance(fill, (tuple, list)):
         raise ValueError("Non-scalar fill value is not supported")

     if mask.ndim < 3:
@@ -1097,7 +1097,7 @@ def pad_bounding_box(
     bounding_box: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
-    padding: Union[int, List[int]],
+    padding: List[int],
     padding_mode: str = "constant",
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     if padding_mode not in ["constant"]:
@@ -1122,8 +1122,8 @@ def pad_bounding_box(
 def pad_video(
     video: torch.Tensor,
-    padding: Union[int, List[int]],
-    fill: datapoints.FillTypeJIT = None,
+    padding: List[int],
+    fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
 ) -> torch.Tensor:
     return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)
@@ -1131,8 +1131,8 @@ def pad_video(
 def pad(
     inpt: datapoints.InputTypeJIT,
-    padding: Union[int, List[int]],
-    fill: datapoints.FillTypeJIT = None,
+    padding: List[int],
+    fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
 ) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
......
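
Taken together with `pad_xfail_jit_fill_condition`, the expectation these annotations encode is roughly the following (a sketch against this revision's prototype namespace; treat the calls as assumptions, not verified output):

    import torch
    import torchvision.prototype.transforms.functional as F

    scripted_pad = torch.jit.script(F.pad)
    image = torch.zeros(3, 8, 8)

    # List[float] vector fill: expected to work under scripting.
    scripted_pad(image, padding=[1, 1, 1, 1], fill=[0.5, 0.5, 0.5])

    # fill=(0.5, 0.5, 0.5) or fill=[1, 2, 3] match pad_xfail_jit_fill_condition,
    # i.e. the tests expect them to fail when scripted.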