Unverified Commit e1f464bd authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Minor speed and nit optimizations on Transform Classes (#6837)

* Change random generator for ColorJitter.

* Move `_convert_fill_arg` from runtime to constructor.

* Remove unnecessary TypeVars.

* Remove unnecessary casts

* Update comments.

* Minor code-quality changes on Geometical Transforms.

* Fixing linter and other minor fixes.

* Change mitigation for mypy.`

* Fixing the tests.

* Fixing the tests.

* Fix linter

* Restore dict copy.

* Handling of defaultdicts

* restore int idiom

* Update todo
parent add75968
...@@ -389,7 +389,7 @@ class TestPad: ...@@ -389,7 +389,7 @@ class TestPad:
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=features.Image)
_ = transform(inpt) _ = transform(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
if isinstance(padding, tuple): if isinstance(padding, tuple):
padding = list(padding) padding = list(padding)
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
...@@ -405,14 +405,14 @@ class TestPad: ...@@ -405,14 +405,14 @@ class TestPad:
_ = transform(inpt) _ = transform(inpt)
if isinstance(fill, int): if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
calls = [ calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"), mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"), mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
] ]
else: else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)]) fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)]) fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [ calls = [
mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"), mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"), mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
...@@ -466,7 +466,7 @@ class TestRandomZoomOut: ...@@ -466,7 +466,7 @@ class TestRandomZoomOut:
torch.rand(1) # random apply changes random state torch.rand(1) # random apply changes random state
params = transform._get_params([inpt]) params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill) fn.assert_called_once_with(inpt, **params, fill=fill)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
...@@ -485,14 +485,14 @@ class TestRandomZoomOut: ...@@ -485,14 +485,14 @@ class TestRandomZoomOut:
params = transform._get_params(inpt) params = transform._get_params(inpt)
if isinstance(fill, int): if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
calls = [ calls = [
mocker.call(image, **params, fill=fill), mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=fill), mocker.call(mask, **params, fill=fill),
] ]
else: else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)]) fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)]) fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [ calls = [
mocker.call(image, **params, fill=fill_img), mocker.call(image, **params, fill=fill_img),
mocker.call(mask, **params, fill=fill_mask), mocker.call(mask, **params, fill=fill_mask),
...@@ -556,7 +556,7 @@ class TestRandomRotation: ...@@ -556,7 +556,7 @@ class TestRandomRotation:
torch.manual_seed(12) torch.manual_seed(12)
params = transform._get_params(inpt) params = transform._get_params(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center) fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
@pytest.mark.parametrize("angle", [34, -87]) @pytest.mark.parametrize("angle", [34, -87])
...@@ -694,7 +694,7 @@ class TestRandomAffine: ...@@ -694,7 +694,7 @@ class TestRandomAffine:
torch.manual_seed(12) torch.manual_seed(12)
params = transform._get_params([inpt]) params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center) fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
...@@ -939,7 +939,7 @@ class TestRandomPerspective: ...@@ -939,7 +939,7 @@ class TestRandomPerspective:
torch.rand(1) # random apply changes random state torch.rand(1) # random apply changes random state
params = transform._get_params([inpt]) params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
...@@ -1009,7 +1009,7 @@ class TestElasticTransform: ...@@ -1009,7 +1009,7 @@ class TestElasticTransform:
transform._get_params = mocker.MagicMock() transform._get_params = mocker.MagicMock()
_ = transform(inpt) _ = transform(inpt)
params = transform._get_params([inpt]) params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill) fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
...@@ -1632,7 +1632,7 @@ class TestFixedSizeCrop: ...@@ -1632,7 +1632,7 @@ class TestFixedSizeCrop:
if not needs_crop: if not needs_crop:
assert args[0] is inpt_sentinel assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel assert args[1] is padding_sentinel
fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel) fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel)
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel) assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else: else:
mock_pad.assert_not_called() mock_pad.assert_not_called()
......
...@@ -983,8 +983,6 @@ class PadIfSmaller(prototype_transforms.Transform): ...@@ -983,8 +983,6 @@ class PadIfSmaller(prototype_transforms.Transform):
return inpt return inpt
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, padding=params["padding"], fill=fill) return F.pad(inpt, padding=params["padding"], fill=fill)
......
import math import math
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -11,9 +11,6 @@ from torchvision.prototype.transforms.functional._meta import get_spatial_size ...@@ -11,9 +11,6 @@ from torchvision.prototype.transforms.functional._meta import get_spatial_size
from ._utils import _isinstance, _setup_fill_arg from ._utils import _isinstance, _setup_fill_arg
K = TypeVar("K")
V = TypeVar("V")
class _AutoAugmentBase(Transform): class _AutoAugmentBase(Transform):
def __init__( def __init__(
...@@ -26,7 +23,7 @@ class _AutoAugmentBase(Transform): ...@@ -26,7 +23,7 @@ class _AutoAugmentBase(Transform):
self.interpolation = interpolation self.interpolation = interpolation
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
keys = tuple(dct.keys()) keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))] key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key] return key, dct[key]
...@@ -71,10 +68,9 @@ class _AutoAugmentBase(Transform): ...@@ -71,10 +68,9 @@ class _AutoAugmentBase(Transform):
transform_id: str, transform_id: str,
magnitude: float, magnitude: float,
interpolation: InterpolationMode, interpolation: InterpolationMode,
fill: Dict[Type, features.FillType], fill: Dict[Type, features.FillTypeJIT],
) -> Union[features.ImageType, features.VideoType]: ) -> Union[features.ImageType, features.VideoType]:
fill_ = fill[type(image)] fill_ = fill[type(image)]
fill_ = F._geometry._convert_fill_arg(fill_)
if transform_id == "Identity": if transform_id == "Identity":
return image return image
...@@ -170,9 +166,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -170,9 +166,7 @@ class AutoAugment(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": ( "Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
.round()
.int(),
False, False,
), ),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
...@@ -327,9 +321,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -327,9 +321,7 @@ class RandAugment(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": ( "Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
.round()
.int(),
False, False,
), ),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
...@@ -383,9 +375,7 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -383,9 +375,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": ( "Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
.round()
.int(),
False, False,
), ),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
...@@ -430,9 +420,7 @@ class AugMix(_AutoAugmentBase): ...@@ -430,9 +420,7 @@ class AugMix(_AutoAugmentBase):
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True), "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": ( "Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
.round()
.int(),
False, False,
), ),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
...@@ -517,7 +505,13 @@ class AugMix(_AutoAugmentBase): ...@@ -517,7 +505,13 @@ class AugMix(_AutoAugmentBase):
aug = self._apply_image_or_video_transform( aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
) )
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug) mix.add_(
# The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`
# Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`.
# TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840
combined_weights[:, i].reshape(batch_dims)
* aug
)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (features.Image, features.Video)): if isinstance(orig_image_or_video, (features.Image, features.Video)):
......
...@@ -51,7 +51,7 @@ class ColorJitter(Transform): ...@@ -51,7 +51,7 @@ class ColorJitter(Transform):
@staticmethod @staticmethod
def _generate_value(left: float, right: float) -> float: def _generate_value(left: float, right: float) -> float:
return float(torch.distributions.Uniform(left, right).sample()) return torch.empty(1).uniform_(left, right).item()
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
fn_idx = torch.randperm(4) fn_idx = torch.randperm(4)
......
...@@ -223,20 +223,16 @@ class Pad(Transform): ...@@ -223,20 +223,16 @@ class Pad(Transform):
_check_padding_arg(padding) _check_padding_arg(padding)
_check_padding_mode_arg(padding_mode) _check_padding_mode_arg(padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
self.padding = padding self.padding = padding
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
padding = self.padding
if not isinstance(padding, int):
padding = list(padding)
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
class RandomZoomOut(_RandomApplyTransform): class RandomZoomOut(_RandomApplyTransform):
...@@ -274,7 +270,6 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -274,7 +270,6 @@ class RandomZoomOut(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, **params, fill=fill) return F.pad(inpt, **params, fill=fill)
...@@ -300,12 +295,11 @@ class RandomRotation(Transform): ...@@ -300,12 +295,11 @@ class RandomRotation(Transform):
self.center = center self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item()) angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
return dict(angle=angle) return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.rotate( return F.rotate(
inpt, inpt,
**params, **params,
...@@ -358,7 +352,7 @@ class RandomAffine(Transform): ...@@ -358,7 +352,7 @@ class RandomAffine(Transform):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs) height, width = query_spatial_size(flat_inputs)
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item()) angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
if self.translate is not None: if self.translate is not None:
max_dx = float(self.translate[0] * width) max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height) max_dy = float(self.translate[1] * height)
...@@ -369,22 +363,21 @@ class RandomAffine(Transform): ...@@ -369,22 +363,21 @@ class RandomAffine(Transform):
translate = (0, 0) translate = (0, 0)
if self.scale is not None: if self.scale is not None:
scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()) scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
else: else:
scale = 1.0 scale = 1.0
shear_x = shear_y = 0.0 shear_x = shear_y = 0.0
if self.shear is not None: if self.shear is not None:
shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()) shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
if len(self.shear) == 4: if len(self.shear) == 4:
shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()) shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
shear = (shear_x, shear_y) shear = (shear_x, shear_y)
return dict(angle=angle, translate=translate, scale=scale, shear=shear) return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.affine( return F.affine(
inpt, inpt,
**params, **params,
...@@ -478,8 +471,6 @@ class RandomCrop(Transform): ...@@ -478,8 +471,6 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]: if params["needs_pad"]:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]: if params["needs_crop"]:
...@@ -512,21 +503,23 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -512,21 +503,23 @@ class RandomPerspective(_RandomApplyTransform):
half_height = height // 2 half_height = height // 2
half_width = width // 2 half_width = width // 2
bound_height = int(distortion_scale * half_height) + 1
bound_width = int(distortion_scale * half_width) + 1
topleft = [ topleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), int(torch.randint(0, bound_height, size=(1,))),
] ]
topright = [ topright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), int(torch.randint(0, bound_height, size=(1,))),
] ]
botright = [ botright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), int(torch.randint(height - bound_height, height, size=(1,))),
] ]
botleft = [ botleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), int(torch.randint(height - bound_height, height, size=(1,))),
] ]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft] endpoints = [topleft, topright, botright, botleft]
...@@ -535,7 +528,6 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -535,7 +528,6 @@ class RandomPerspective(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.perspective( return F.perspective(
inpt, inpt,
**params, **params,
...@@ -584,7 +576,6 @@ class ElasticTransform(Transform): ...@@ -584,7 +576,6 @@ class ElasticTransform(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.elastic( return F.elastic(
inpt, inpt,
**params, **params,
...@@ -855,7 +846,6 @@ class FixedSizeCrop(Transform): ...@@ -855,7 +846,6 @@ class FixedSizeCrop(Transform):
if params["needs_pad"]: if params["needs_pad"]:
fill = self.fill[type(inpt)] fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt return inpt
......
from typing import Any, cast, Dict, Optional, Union from typing import Any, Dict, Optional, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -13,7 +13,7 @@ class DecodeImage(Transform): ...@@ -13,7 +13,7 @@ class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,) _transformed_types = (features.EncodedImage,)
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image: def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
return cast(features.Image, F.decode_image_with_pil(inpt)) return F.decode_image_with_pil(inpt) # type: ignore[no-any-return]
class LabelToOneHot(Transform): class LabelToOneHot(Transform):
...@@ -27,7 +27,7 @@ class LabelToOneHot(Transform): ...@@ -27,7 +27,7 @@ class LabelToOneHot(Transform):
num_categories = self.num_categories num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None: if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories) num_categories = len(inpt.categories)
output = one_hot(inpt, num_classes=num_categories) output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return features.OneHotLabel(output, categories=inpt.categories) return features.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str: def extra_repr(self) -> str:
...@@ -50,7 +50,7 @@ class ToImageTensor(Transform): ...@@ -50,7 +50,7 @@ class ToImageTensor(Transform):
def _transform( def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> features.Image: ) -> features.Image:
return cast(features.Image, F.to_image_tensor(inpt)) return F.to_image_tensor(inpt) # type: ignore[no-any-return]
class ToImagePIL(Transform): class ToImagePIL(Transform):
......
...@@ -7,7 +7,7 @@ import PIL.Image ...@@ -7,7 +7,7 @@ import PIL.Image
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.features._feature import FillType from torchvision.prototype.features._feature import FillType, FillTypeJIT
from torchvision.prototype.transforms.functional._meta import get_dimensions, get_spatial_size from torchvision.prototype.transforms.functional._meta import get_dimensions, get_spatial_size
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
...@@ -37,9 +37,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: ...@@ -37,9 +37,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
for key, value in fill.items(): for key, value in fill.items():
# Check key for type # Check key for type
_check_fill_arg(value) _check_fill_arg(value)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
_check_fill_arg(default_value)
else: else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg") raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
T = TypeVar("T") T = TypeVar("T")
...@@ -55,13 +58,33 @@ def _get_defaultdict(default: T) -> Dict[Any, T]: ...@@ -55,13 +58,33 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
return defaultdict(functools.partial(_default_arg, default)) return defaultdict(functools.partial(_default_arg, default))
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# if fill is None:
# fill = 0
if fill is None:
return fill
# This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)]
return fill
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]:
_check_fill_arg(fill) _check_fill_arg(fill)
if isinstance(fill, dict): if isinstance(fill, dict):
return fill for k, v in fill.items():
fill[k] = _convert_fill_arg(v)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
sanitized_default = _convert_fill_arg(default_value)
fill.default_factory = functools.partial(_default_arg, sanitized_default)
return fill # type: ignore[return-value]
return _get_defaultdict(fill) return _get_defaultdict(_convert_fill_arg(fill))
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
...@@ -80,7 +103,7 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", ...@@ -80,7 +103,7 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect",
def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox: def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
bounding_boxes = {inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)} bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
if not bounding_boxes: if not bounding_boxes:
raise TypeError("No bounding box was found in the sample") raise TypeError("No bounding box was found in the sample")
elif len(bounding_boxes) > 1: elif len(bounding_boxes) > 1:
......
...@@ -470,20 +470,6 @@ def affine_video( ...@@ -470,20 +470,6 @@ def affine_video(
) )
def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# if fill is None:
# fill = 0
if fill is None:
return fill
# This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)]
return fill
def affine( def affine(
inpt: features.InputTypeJIT, inpt: features.InputTypeJIT,
angle: Union[int, float], angle: Union[int, float],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment