Unverified Commit e1f464bd authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Minor speed and nit optimizations on Transform Classes (#6837)

* Change random generator for ColorJitter.

* Move `_convert_fill_arg` from runtime to constructor.

* Remove unnecessary TypeVars.

* Remove unnecessary casts

* Update comments.

* Minor code-quality changes on Geometrical Transforms.

* Fixing linter and other minor fixes.

* Change mitigation for mypy.

* Fixing the tests.

* Fixing the tests.

* Fix linter

* Restore dict copy.

* Handling of defaultdicts

* restore int idiom

* Update todo
parent add75968
......@@ -389,7 +389,7 @@ class TestPad:
inpt = mocker.MagicMock(spec=features.Image)
_ = transform(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
if isinstance(padding, tuple):
padding = list(padding)
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
......@@ -405,14 +405,14 @@ class TestPad:
_ = transform(inpt)
if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
]
else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
......@@ -466,7 +466,7 @@ class TestRandomZoomOut:
torch.rand(1) # random apply changes random state
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
......@@ -485,14 +485,14 @@ class TestRandomZoomOut:
params = transform._get_params(inpt)
if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=fill),
]
else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, **params, fill=fill_img),
mocker.call(mask, **params, fill=fill_mask),
......@@ -556,7 +556,7 @@ class TestRandomRotation:
torch.manual_seed(12)
params = transform._get_params(inpt)
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
@pytest.mark.parametrize("angle", [34, -87])
......@@ -694,7 +694,7 @@ class TestRandomAffine:
torch.manual_seed(12)
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
......@@ -939,7 +939,7 @@ class TestRandomPerspective:
torch.rand(1) # random apply changes random state
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
......@@ -1009,7 +1009,7 @@ class TestElasticTransform:
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
......@@ -1632,7 +1632,7 @@ class TestFixedSizeCrop:
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel)
fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel)
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()
......
......@@ -983,8 +983,6 @@ class PadIfSmaller(prototype_transforms.Transform):
return inpt
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, padding=params["padding"], fill=fill)
......
import math
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import PIL.Image
import torch
......@@ -11,9 +11,6 @@ from torchvision.prototype.transforms.functional._meta import get_spatial_size
from ._utils import _isinstance, _setup_fill_arg
K = TypeVar("K")
V = TypeVar("V")
class _AutoAugmentBase(Transform):
def __init__(
......@@ -26,7 +23,7 @@ class _AutoAugmentBase(Transform):
self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
......@@ -71,10 +68,9 @@ class _AutoAugmentBase(Transform):
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Dict[Type, features.FillType],
fill: Dict[Type, features.FillTypeJIT],
) -> Union[features.ImageType, features.VideoType]:
fill_ = fill[type(image)]
fill_ = F._geometry._convert_fill_arg(fill_)
if transform_id == "Identity":
return image
......@@ -170,9 +166,7 @@ class AutoAugment(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
......@@ -327,9 +321,7 @@ class RandAugment(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
......@@ -383,9 +375,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
.round()
.int(),
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
......@@ -430,9 +420,7 @@ class AugMix(_AutoAugmentBase):
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
......@@ -517,7 +505,13 @@ class AugMix(_AutoAugmentBase):
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix.add_(
# The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`
# Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`.
# TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840
combined_weights[:, i].reshape(batch_dims)
* aug
)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (features.Image, features.Video)):
......
......@@ -51,7 +51,7 @@ class ColorJitter(Transform):
@staticmethod
def _generate_value(left: float, right: float) -> float:
return float(torch.distributions.Uniform(left, right).sample())
return torch.empty(1).uniform_(left, right).item()
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
fn_idx = torch.randperm(4)
......
......@@ -223,20 +223,16 @@ class Pad(Transform):
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
self.padding = padding
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Pad `inpt` with the padding/fill/padding_mode configured at construction.

    The diff artifact left the pre-change body (per-call `list(padding)`
    normalization, per-call `_convert_fill_arg`, and a duplicate return) above
    the new return, which made the new line unreachable; only the post-change
    version is kept. Padding normalization and fill conversion are assumed to
    happen once in `__init__` via `_setup_fill_arg` — confirm against the
    constructor.
    """
    # Fill values are stored per input type, pre-converted to JIT-ready form.
    fill = self.fill[type(inpt)]
    return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
class RandomZoomOut(_RandomApplyTransform):
......@@ -274,7 +270,6 @@ class RandomZoomOut(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, **params, fill=fill)
......@@ -300,12 +295,11 @@ class RandomRotation(Transform):
self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.rotate(
inpt,
**params,
......@@ -358,7 +352,7 @@ class RandomAffine(Transform):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
if self.translate is not None:
max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height)
......@@ -369,22 +363,21 @@ class RandomAffine(Transform):
translate = (0, 0)
if self.scale is not None:
scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
else:
scale = 1.0
shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
if len(self.shear) == 4:
shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
shear = (shear_x, shear_y)
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.affine(
inpt,
**params,
......@@ -478,8 +471,6 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]:
......@@ -512,21 +503,23 @@ class RandomPerspective(_RandomApplyTransform):
half_height = height // 2
half_width = width // 2
bound_height = int(distortion_scale * half_height) + 1
bound_width = int(distortion_scale * half_width) + 1
topleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
topright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
botright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
botleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
......@@ -535,7 +528,6 @@ class RandomPerspective(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.perspective(
inpt,
**params,
......@@ -584,7 +576,6 @@ class ElasticTransform(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.elastic(
inpt,
**params,
......@@ -855,7 +846,6 @@ class FixedSizeCrop(Transform):
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
......
from typing import Any, cast, Dict, Optional, Union
from typing import Any, Dict, Optional, Union
import numpy as np
import PIL.Image
......@@ -13,7 +13,7 @@ class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,)
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
    """Decode an encoded-image tensor into a `features.Image` via PIL."""
    # The stale pre-change `cast(features.Image, ...)` return above this one
    # made it unreachable; the runtime cast is dropped in favor of silencing
    # mypy's no-any-return, which has no runtime cost.
    return F.decode_image_with_pil(inpt)  # type: ignore[no-any-return]
class LabelToOneHot(Transform):
......@@ -27,7 +27,7 @@ class LabelToOneHot(Transform):
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories)
output = one_hot(inpt, num_classes=num_categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return features.OneHotLabel(output, categories=inpt.categories)
def extra_repr(self) -> str:
......@@ -50,7 +50,7 @@ class ToImageTensor(Transform):
def _transform(
    self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> features.Image:
    """Convert a tensor / PIL image / ndarray into a `features.Image`."""
    # Same pattern as DecodeImage: drop the runtime `cast` (the stale old
    # return line above shadowed this one) and silence mypy instead.
    return F.to_image_tensor(inpt)  # type: ignore[no-any-return]
class ToImagePIL(Transform):
......
......@@ -7,7 +7,7 @@ import PIL.Image
from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.prototype.features._feature import FillType
from torchvision.prototype.features._feature import FillType, FillTypeJIT
from torchvision.prototype.transforms.functional._meta import get_dimensions, get_spatial_size
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
......@@ -37,9 +37,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
_check_fill_arg(default_value)
else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
T = TypeVar("T")
......@@ -55,13 +58,33 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
return defaultdict(functools.partial(_default_arg, default))
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
    """Normalize a user-facing fill value into its JIT-scriptable form.

    ``None`` is passed through untouched, scalars are returned as-is, and any
    other sequence is converted to a ``List[float]`` so that both mypy and
    ``torch.jit.script`` accept it.
    """
    # Fill = 0 is NOT equivalent to None (https://github.com/pytorch/vision/issues/6517),
    # so None must be preserved rather than defaulted to 0.
    if fill is None:
        return None
    if isinstance(fill, (int, float)):
        return fill
    # Sequence -> List[float]: required to please mypy and torch.jit.script.
    return [float(v) for v in fill]
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]:
    """Validate `fill` and return a per-type mapping of JIT-ready fill values.

    A plain value becomes a defaultdict that maps every type to the converted
    value; a dict has its values converted in place. The diff artifact had
    left a stale early `return fill` (which made the conversion loop dead
    code) and a stale duplicate final `return _get_defaultdict(fill)`; both
    pre-change lines are removed so the post-change logic actually runs.
    """
    _check_fill_arg(fill)

    if isinstance(fill, dict):
        for k, v in fill.items():
            fill[k] = _convert_fill_arg(v)
        # A defaultdict must also have its default sanitized, otherwise fills
        # produced for unseen types would bypass _convert_fill_arg.
        if isinstance(fill, defaultdict) and callable(fill.default_factory):
            default_value = fill.default_factory()
            sanitized_default = _convert_fill_arg(default_value)
            fill.default_factory = functools.partial(_default_arg, sanitized_default)
        return fill  # type: ignore[return-value]

    return _get_defaultdict(_convert_fill_arg(fill))
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
......@@ -80,7 +103,7 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect",
def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
bounding_boxes = {inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)}
bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
if not bounding_boxes:
raise TypeError("No bounding box was found in the sample")
elif len(bounding_boxes) > 1:
......
......@@ -470,20 +470,6 @@ def affine_video(
)
def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
    """Normalize a fill value into the JIT-scriptable form used by the kernels.

    NOTE(review): this appears to duplicate ``transforms._utils._convert_fill_arg``
    (the tests in this change reference that location) — confirm that this
    copy can be removed.
    """
    # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
    # So, we can't reassign fill to 0
    # if fill is None:
    #     fill = 0
    if fill is None:
        # Preserve None so downstream kernels can distinguish "no fill" from 0.
        return fill
    # This cast does Sequence -> List[float] to please mypy and torch.jit.script
    if not isinstance(fill, (int, float)):
        fill = [float(v) for v in list(fill)]
    return fill
def affine(
inpt: features.InputTypeJIT,
angle: Union[int, float],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment