Unverified Commit c0911e31 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Update typehint for fill arg in rotate (#6594)

parent 753bf186
......@@ -102,18 +102,20 @@ def affine_mask():
@register_kernel_info_from_sample_inputs_fn
def rotate_image_tensor():
for image, angle, expand, center, fill in itertools.product(
for image, angle, expand, center in itertools.product(
make_images(),
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
[None, [128], [12.0]], # fill
):
if center is not None and expand:
# Skip warning: The provided center argument is ignored if expand is True
continue
yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=fill)
yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=None)
for fill in [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]]:
yield ArgsKwargs(image, angle=23, expand=False, center=None, fill=fill)
@register_kernel_info_from_sample_inputs_fn
......
......@@ -467,7 +467,7 @@ def rotate_image_tensor(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
num_channels, height, width = img.shape[-3:]
......
......@@ -475,7 +475,7 @@ def _assert_grid_transform_inputs(
img: Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[List[float]],
fill: Optional[Union[int, float, List[float]]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
......@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(
# Check fill
num_channels = get_dimensions(img)[0]
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
......@@ -539,7 +539,9 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
return img
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
def _apply_grid_transform(
img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
......@@ -559,8 +561,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img)
len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == "nearest":
mask = mask < 0.5
img[mask] = fill_img[mask]
......@@ -648,7 +650,7 @@ def rotate(
matrix: List[float],
interpolation: str = "nearest",
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment