Unverified Commit c0911e31 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Update typehint for fill arg in rotate (#6594)

parent 753bf186
...@@ -102,18 +102,20 @@ def affine_mask(): ...@@ -102,18 +102,20 @@ def affine_mask():
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def rotate_image_tensor(): def rotate_image_tensor():
for image, angle, expand, center, fill in itertools.product( for image, angle, expand, center in itertools.product(
make_images(), make_images(),
[-87, 15, 90], # angle [-87, 15, 90], # angle
[True, False], # expand [True, False], # expand
[None, [12, 23]], # center [None, [12, 23]], # center
[None, [128], [12.0]], # fill
): ):
if center is not None and expand: if center is not None and expand:
# Skip warning: The provided center argument is ignored if expand is True # Skip warning: The provided center argument is ignored if expand is True
continue continue
yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=fill) yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=None)
for fill in [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]]:
yield ArgsKwargs(image, angle=23, expand=False, center=None, fill=fill)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
......
...@@ -467,7 +467,7 @@ def rotate_image_tensor( ...@@ -467,7 +467,7 @@ def rotate_image_tensor(
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Optional[List[float]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_channels, height, width = img.shape[-3:] num_channels, height, width = img.shape[-3:]
......
...@@ -475,7 +475,7 @@ def _assert_grid_transform_inputs( ...@@ -475,7 +475,7 @@ def _assert_grid_transform_inputs(
img: Tensor, img: Tensor,
matrix: Optional[List[float]], matrix: Optional[List[float]],
interpolation: str, interpolation: str,
fill: Optional[List[float]], fill: Optional[Union[int, float, List[float]]],
supported_interpolation_modes: List[str], supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None, coeffs: Optional[List[float]] = None,
) -> None: ) -> None:
...@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs( ...@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(
# Check fill # Check fill
num_channels = get_dimensions(img)[0] num_channels = get_dimensions(img)[0]
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = ( msg = (
"The number of elements in 'fill' cannot broadcast to match the number of " "The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})" "channels of the image ({} != {})"
...@@ -539,7 +539,9 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp ...@@ -539,7 +539,9 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
return img return img
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor: def _apply_grid_transform(
img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype]) img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])
...@@ -559,8 +561,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L ...@@ -559,8 +561,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
mask = img[:, -1:, :, :] # N * 1 * H * W mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img) mask = mask.expand_as(img)
len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == "nearest": if mode == "nearest":
mask = mask < 0.5 mask = mask < 0.5
img[mask] = fill_img[mask] img[mask] = fill_img[mask]
...@@ -648,7 +650,7 @@ def rotate( ...@@ -648,7 +650,7 @@ def rotate(
matrix: List[float], matrix: List[float],
interpolation: str = "nearest", interpolation: str = "nearest",
expand: bool = False, expand: bool = False,
fill: Optional[List[float]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor: ) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2] w, h = img.shape[-1], img.shape[-2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment