Unverified Commit 7cc2c95a authored by vfdev, committed by GitHub


[proto] Consistent supported/unsupported types handling in LinearTransformation, other perf comments (#6498)

* WIP

* [proto] Uniform supported/unsupported types handling in LinearTransformation, other perf comments

* Type fixes and other minor stuff
parent becaba0e
@@ -1118,6 +1118,15 @@ class TestContainers:
         assert isinstance(output, torch.Tensor)
 
 
+class TestRandomChoice:
+    def test_assertions(self):
+        with pytest.warns(UserWarning, match="Argument p is deprecated and will be removed"):
+            transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1, 2])
+
+        with pytest.raises(ValueError, match="The number of probabilities doesn't match the number of transforms"):
+            transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], probabilities=[1])
+
+
 class TestRandomIoUCrop:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
@@ -1616,7 +1625,7 @@ class TestLinearTransformation:
         transform = transforms.LinearTransformation(m, v)
 
         if isinstance(inpt, PIL.Image.Image):
-            with pytest.raises(TypeError, match="Unsupported input type"):
+            with pytest.raises(TypeError, match="LinearTransformation does not work on PIL Images"):
                 transform(inpt)
         else:
             output = transform(inpt)
......
@@ -245,8 +245,10 @@ class SimpleCopyPaste(_RandomApplyTransform):
         # Copy-paste boxes and labels
         bbox_format = target["boxes"].format
         xyxy_boxes = masks_to_boxes(masks)
-        # TODO: masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive
-        # we need to add +1 to x2y2. We need to investigate that.
+        # masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive,
+        # so we need to add +1 to x2y2.
+        # There is a similar +1 in other reference implementations:
+        # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
         xyxy_boxes[:, 2:] += 1
         boxes = F.convert_bounding_box_format(
             xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
......
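A quick way to see the inclusive-coordinate convention described in the comment above (a minimal sketch using torchvision.ops.masks_to_boxes; the values only illustrate the off-by-one):

import torch
from torchvision.ops import masks_to_boxes

# A mask with a single foreground pixel at row 2, col 3.
mask = torch.zeros(1, 5, 5, dtype=torch.bool)
mask[0, 2, 3] = True

box = masks_to_boxes(mask)  # tensor([[3., 2., 3., 2.]]) -- x2y2 inclusive, zero area
box[:, 2:] += 1             # tensor([[3., 2., 4., 3.]]) -- exclusive, width == height == 1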
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence
 
 import torch
@@ -33,9 +34,21 @@ class RandomApply(_RandomApplyTransform):
 
 
 class RandomChoice(Transform):
-    def __init__(self, transforms: Sequence[Callable], probabilities: Optional[List[float]] = None) -> None:
+    def __init__(
+        self,
+        transforms: Sequence[Callable],
+        probabilities: Optional[List[float]] = None,
+        p: Optional[List[float]] = None,
+    ) -> None:
         if not isinstance(transforms, Sequence):
             raise TypeError("Argument transforms should be a sequence of callables")
+
+        if p is not None:
+            warnings.warn(
+                "Argument p is deprecated and will be removed in a future release. "
+                "Please use probabilities argument instead."
+            )
+            probabilities = p
+
         if probabilities is None:
             probabilities = [1] * len(transforms)
         elif len(probabilities) != len(transforms):
@@ -48,7 +61,7 @@ class RandomChoice(Transform):
 
         self.transforms = transforms
         total = sum(probabilities)
-        self.probabilities = [p / total for p in probabilities]
+        self.probabilities = [prob / total for prob in probabilities]
 
     def forward(self, *inputs: Any) -> Any:
         idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
......
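A short usage sketch of the deprecation path added above (assuming the prototype transforms namespace as imported in the tests):

import warnings
from torchvision.prototype import transforms

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    t = transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1, 2])

assert any("Argument p is deprecated" in str(w.message) for w in caught)
assert t.probabilities == [1 / 3, 2 / 3]  # normalized by total = 3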
@@ -53,7 +53,7 @@ class ToPILImage(Transform):
         super().__init__()
         self.mode = mode
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
         return _F.to_pil_image(inpt, mode=self.mode)
......
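The return-annotation fix above matters because PIL.Image is a module, while PIL.Image.Image is the image class; a minimal check:

import inspect

import PIL.Image

img = PIL.Image.new("RGB", (4, 4))
assert inspect.ismodule(PIL.Image)       # the old annotation named a module
assert isinstance(img, PIL.Image.Image)  # the class the function actually returns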
@@ -462,6 +462,7 @@ class RandomCrop(Transform):
         )
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        # TODO: (PERF) check for speed optimization if we avoid repeated pad calls
         if self.padding is not None:
             inpt = F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
......
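For context, RandomCrop can currently issue several F.pad calls: the explicit padding above, then more inside the pad_if_needed logic for width and height. One hypothetical shape of the optimization the TODO hints at, with illustrative names only and not what this PR implements:

# Hypothetical sketch: fold all padding into a single F.pad call.
def _pad_once(inpt, padding, size, fill, padding_mode):
    left, top, right, bottom = padding if padding is not None else (0, 0, 0, 0)
    _, height, width = get_chw(inpt)  # helper touched elsewhere in this PR
    right += max(size[1] - (width + left + right), 0)    # grow to fit crop width
    bottom += max(size[0] - (height + top + bottom), 0)  # grow to fit crop height
    return F.pad(inpt, padding=[left, top, right, bottom], fill=fill, padding_mode=padding_mode)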
@@ -7,11 +7,9 @@ import torch
 from torchvision.ops import remove_small_boxes
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.prototype.transforms._utils import query_bounding_box
+from torchvision.prototype.transforms._utils import has_any, is_simple_tensor, query_bounding_box
 from torchvision.transforms.transforms import _setup_size
 
-from ._utils import is_simple_tensor
-
 
 class Identity(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -40,6 +38,8 @@ class Lambda(Transform):
 
 
 class LinearTransformation(Transform):
+    _transformed_types = (is_simple_tensor, features.Image)
+
     def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
         super().__init__()
         if transformation_matrix.size(0) != transformation_matrix.size(1):
@@ -62,13 +62,14 @@ class LinearTransformation(Transform):
         self.transformation_matrix = transformation_matrix
         self.mean_vector = mean_vector
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
-            return inpt
-        elif isinstance(inpt, PIL.Image.Image):
-            raise TypeError("Unsupported input type")
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, PIL.Image.Image):
+            raise TypeError("LinearTransformation does not work on PIL Images")
+        return super().forward(sample)
 
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         # Image instance after linear transformation is not Image anymore due to unknown data range
         # Thus we will return Tensor for input Image
......
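A behavior sketch for the new handling (assuming the prototype features API used elsewhere in this PR; shapes are illustrative):

import torch
import PIL.Image
from torchvision.prototype import features, transforms

d = 3 * 8 * 8
t = transforms.LinearTransformation(torch.eye(d), torch.zeros(d))

out = t(features.Image(torch.rand(3, 8, 8)))
assert type(out) is torch.Tensor  # an Image comes back as a plain Tensor (unknown data range)

bbox = features.BoundingBox(torch.zeros(4), format="XYXY", image_size=(8, 8))
t(bbox)  # not in _transformed_types, so it is passed through unchanged

t(PIL.Image.new("RGB", (8, 8)))  # raises TypeError: does not work on PIL Images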
@@ -22,7 +22,7 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tupl
     if isinstance(image, features.Image):
         channels = image.num_channels
         height, width = image.image_size
-    elif isinstance(image, torch.Tensor):
+    elif is_simple_tensor(image):
         channels, height, width = get_dimensions_image_tensor(image)
     elif isinstance(image, PIL.Image.Image):
         channels, height, width = get_dimensions_image_pil(image)
......
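Switching to is_simple_tensor matters because the prototype feature classes subclass torch.Tensor, so the bare isinstance check would also match tensors like bounding boxes or masks. Roughly what the helper does (a sketch; the real one lives in torchvision.prototype.transforms._utils):

def is_simple_tensor(inpt):
    # True only for a bare torch.Tensor, not for feature subclasses
    # such as features.Image or features.BoundingBox.
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, features._Feature)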
@@ -64,6 +64,8 @@ def convert_bounding_box_format(
 def clamp_bounding_box(
     bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
 ) -> torch.Tensor:
+    # TODO: (PERF) Possible speed up clamping if we have different implementations for each bbox format.
+    # Not sure if they yield equivalent results.
     xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY)
     xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
......
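A worked example of the clamping above; image_size is (height, width), so x coordinates clamp against index 1 and y coordinates against index 0:

import torch

image_size = (200, 100)  # (height, width)
xyxy = torch.tensor([[-5.0, 10.0, 120.0, 300.0]])
xyxy[..., 0::2].clamp_(min=0, max=image_size[1])  # x1, x2 -> 0.0, 100.0
xyxy[..., 1::2].clamp_(min=0, max=image_size[0])  # y1, y2 -> 10.0, 200.0
# xyxy is now tensor([[0., 10., 100., 200.]])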
@@ -54,7 +54,9 @@ def gaussian_blur_image_tensor(
     return _FT.gaussian_blur(img, kernel_size, sigma)
 
 
-def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optional[List[float]] = None) -> PIL.Image:
+def gaussian_blur_image_pil(
+    img: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
+) -> PIL.Image.Image:
     t_img = pil_to_tensor(img)
     output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
     return to_pil_image(output, mode=img.mode)
......
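A usage sketch for the corrected PIL signature (the function round-trips PIL -> tensor -> blur -> PIL and preserves the input mode):

import PIL.Image

img = PIL.Image.new("RGB", (32, 32), color=(128, 64, 32))
out = gaussian_blur_image_pil(img, kernel_size=[3, 3], sigma=[1.0])
assert isinstance(out, PIL.Image.Image) and out.mode == img.mode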