Unverified Commit 0de3e5b4 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Aligned fill, padding typehints between Features and F (#6616)

parent 841b9a19
...@@ -115,7 +115,7 @@ class BoundingBox(_Feature): ...@@ -115,7 +115,7 @@ class BoundingBox(_Feature):
def pad( def pad(
self, self,
padding: Union[int, Sequence[int]], padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> BoundingBox: ) -> BoundingBox:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy # This cast does Sequence[int] -> List[int] and is required to make mypy happy
...@@ -137,7 +137,7 @@ class BoundingBox(_Feature): ...@@ -137,7 +137,7 @@ class BoundingBox(_Feature):
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBox:
output = self._F.rotate_bounding_box( output = self._F.rotate_bounding_box(
...@@ -165,7 +165,7 @@ class BoundingBox(_Feature): ...@@ -165,7 +165,7 @@ class BoundingBox(_Feature):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBox:
output = self._F.affine_bounding_box( output = self._F.affine_bounding_box(
...@@ -184,7 +184,7 @@ class BoundingBox(_Feature): ...@@ -184,7 +184,7 @@ class BoundingBox(_Feature):
self, self,
perspective_coeffs: List[float], perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> BoundingBox: ) -> BoundingBox:
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype) return BoundingBox.new_like(self, output, dtype=output.dtype)
...@@ -193,7 +193,7 @@ class BoundingBox(_Feature): ...@@ -193,7 +193,7 @@ class BoundingBox(_Feature):
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> BoundingBox: ) -> BoundingBox:
output = self._F.elastic_bounding_box(self, self.format, displacement) output = self._F.elastic_bounding_box(self, self.format, displacement)
return BoundingBox.new_like(self, output, dtype=output.dtype) return BoundingBox.new_like(self, output, dtype=output.dtype)
...@@ -153,8 +153,8 @@ class _Feature(torch.Tensor): ...@@ -153,8 +153,8 @@ class _Feature(torch.Tensor):
def pad( def pad(
self, self,
padding: Union[int, Sequence[int]], padding: Union[int, List[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> _Feature: ) -> _Feature:
return self return self
...@@ -164,7 +164,7 @@ class _Feature(torch.Tensor): ...@@ -164,7 +164,7 @@ class _Feature(torch.Tensor):
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> _Feature: ) -> _Feature:
return self return self
...@@ -176,7 +176,7 @@ class _Feature(torch.Tensor): ...@@ -176,7 +176,7 @@ class _Feature(torch.Tensor):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> _Feature: ) -> _Feature:
return self return self
...@@ -185,7 +185,7 @@ class _Feature(torch.Tensor): ...@@ -185,7 +185,7 @@ class _Feature(torch.Tensor):
self, self,
perspective_coeffs: List[float], perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> _Feature: ) -> _Feature:
return self return self
...@@ -193,7 +193,7 @@ class _Feature(torch.Tensor): ...@@ -193,7 +193,7 @@ class _Feature(torch.Tensor):
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> _Feature: ) -> _Feature:
return self return self
......
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import Any, cast, List, Optional, Sequence, Tuple, Union from typing import Any, cast, List, Optional, Tuple, Union
import torch import torch
from torchvision._utils import StrEnum from torchvision._utils import StrEnum
...@@ -180,16 +180,10 @@ class Image(_Feature): ...@@ -180,16 +180,10 @@ class Image(_Feature):
def pad( def pad(
self, self,
padding: Union[int, Sequence[int]], padding: Union[int, List[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> Image: ) -> Image:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
fill = self._F._geometry._convert_fill_arg(fill)
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
return Image.new_like(self, output) return Image.new_like(self, output)
...@@ -198,11 +192,9 @@ class Image(_Feature): ...@@ -198,11 +192,9 @@ class Image(_Feature):
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Image: ) -> Image:
fill = self._F._geometry._convert_fill_arg(fill)
output = self._F._geometry.rotate_image_tensor( output = self._F._geometry.rotate_image_tensor(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
) )
...@@ -215,11 +207,9 @@ class Image(_Feature): ...@@ -215,11 +207,9 @@ class Image(_Feature):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Image: ) -> Image:
fill = self._F._geometry._convert_fill_arg(fill)
output = self._F._geometry.affine_image_tensor( output = self._F._geometry.affine_image_tensor(
self, self,
angle, angle,
...@@ -236,10 +226,8 @@ class Image(_Feature): ...@@ -236,10 +226,8 @@ class Image(_Feature):
self, self,
perspective_coeffs: List[float], perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> Image: ) -> Image:
fill = self._F._geometry._convert_fill_arg(fill)
output = self._F._geometry.perspective_image_tensor( output = self._F._geometry.perspective_image_tensor(
self, perspective_coeffs, interpolation=interpolation, fill=fill self, perspective_coeffs, interpolation=interpolation, fill=fill
) )
...@@ -249,10 +237,8 @@ class Image(_Feature): ...@@ -249,10 +237,8 @@ class Image(_Feature):
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> Image: ) -> Image:
fill = self._F._geometry._convert_fill_arg(fill)
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill) output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
return Image.new_like(self, output) return Image.new_like(self, output)
......
from __future__ import annotations from __future__ import annotations
from typing import List, Optional, Sequence, Union from typing import List, Optional, Union
import torch import torch
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
...@@ -50,16 +50,10 @@ class Mask(_Feature): ...@@ -50,16 +50,10 @@ class Mask(_Feature):
def pad( def pad(
self, self,
padding: Union[int, Sequence[int]], padding: Union[int, List[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> Mask: ) -> Mask:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
fill = self._F._geometry._convert_fill_arg(fill)
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
return Mask.new_like(self, output) return Mask.new_like(self, output)
...@@ -68,10 +62,10 @@ class Mask(_Feature): ...@@ -68,10 +62,10 @@ class Mask(_Feature):
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Mask: ) -> Mask:
output = self._F.rotate_mask(self, angle, expand=expand, center=center) output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
return Mask.new_like(self, output) return Mask.new_like(self, output)
def affine( def affine(
...@@ -81,7 +75,7 @@ class Mask(_Feature): ...@@ -81,7 +75,7 @@ class Mask(_Feature):
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Mask: ) -> Mask:
output = self._F.affine_mask( output = self._F.affine_mask(
...@@ -90,6 +84,7 @@ class Mask(_Feature): ...@@ -90,6 +84,7 @@ class Mask(_Feature):
translate=translate, translate=translate,
scale=scale, scale=scale,
shear=shear, shear=shear,
fill=fill,
center=center, center=center,
) )
return Mask.new_like(self, output) return Mask.new_like(self, output)
...@@ -98,16 +93,16 @@ class Mask(_Feature): ...@@ -98,16 +93,16 @@ class Mask(_Feature):
self, self,
perspective_coeffs: List[float], perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> Mask: ) -> Mask:
output = self._F.perspective_mask(self, perspective_coeffs) output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
return Mask.new_like(self, output) return Mask.new_like(self, output)
def elastic( def elastic(
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> Mask: ) -> Mask:
output = self._F.elastic_mask(self, displacement) output = self._F.elastic_mask(self, displacement, fill=fill)
return Mask.new_like(self, output, dtype=output.dtype) return Mask.new_like(self, output, dtype=output.dtype)
...@@ -379,6 +379,7 @@ def affine_mask( ...@@ -379,6 +379,7 @@ def affine_mask(
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if mask.ndim < 3: if mask.ndim < 3:
...@@ -394,6 +395,7 @@ def affine_mask( ...@@ -394,6 +395,7 @@ def affine_mask(
scale=scale, scale=scale,
shear=shear, shear=shear,
interpolation=InterpolationMode.NEAREST, interpolation=InterpolationMode.NEAREST,
fill=fill,
center=center, center=center,
) )
...@@ -541,6 +543,7 @@ def rotate_mask( ...@@ -541,6 +543,7 @@ def rotate_mask(
mask: torch.Tensor, mask: torch.Tensor,
angle: float, angle: float,
expand: bool = False, expand: bool = False,
fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if mask.ndim < 3: if mask.ndim < 3:
...@@ -554,6 +557,7 @@ def rotate_mask( ...@@ -554,6 +557,7 @@ def rotate_mask(
angle=angle, angle=angle,
expand=expand, expand=expand,
interpolation=InterpolationMode.NEAREST, interpolation=InterpolationMode.NEAREST,
fill=fill,
center=center, center=center,
) )
...@@ -849,7 +853,11 @@ def perspective_bounding_box( ...@@ -849,7 +853,11 @@ def perspective_bounding_box(
).view(original_shape) ).view(original_shape)
def perspective_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: def perspective_mask(
mask: torch.Tensor,
perspective_coeffs: List[float],
fill: Optional[Union[int, float, List[float]]] = None,
) -> torch.Tensor:
if mask.ndim < 3: if mask.ndim < 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
...@@ -857,7 +865,7 @@ def perspective_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> tor ...@@ -857,7 +865,7 @@ def perspective_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> tor
needs_squeeze = False needs_squeeze = False
output = perspective_image_tensor( output = perspective_image_tensor(
mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST, fill=fill
) )
if needs_squeeze: if needs_squeeze:
...@@ -944,14 +952,18 @@ def elastic_bounding_box( ...@@ -944,14 +952,18 @@ def elastic_bounding_box(
).view(original_shape) ).view(original_shape)
def elastic_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor: def elastic_mask(
mask: torch.Tensor,
displacement: torch.Tensor,
fill: Optional[Union[int, float, List[float]]] = None,
) -> torch.Tensor:
if mask.ndim < 3: if mask.ndim < 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST) output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)
if needs_squeeze: if needs_squeeze:
output = output.squeeze(0) output = output.squeeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment