Unverified Commit 2a5fbcdd authored by vfdev, committed by GitHub

[proto] Fixed F.perspective signature (#6617)


Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent cffb7f7f
@@ -113,21 +113,14 @@ DISPATCHER_INFOS = [
             features.Mask: F.pad_mask,
         },
     ),
-    # FIXME:
-    # RuntimeError: perspective() is missing value for argument 'startpoints'.
-    # Declaration: perspective(Tensor inpt, int[][] startpoints, int[][] endpoints,
-    # Enum<__torch__.torchvision.transforms.functional.InterpolationMode> interpolation=Enum<InterpolationMode.BILINEAR>,
-    # Union(float[], float, int, NoneType) fill=None) -> Tensor
-    #
-    # This is probably due to the fact that F.perspective does not have the same signature as F.perspective_image_tensor
-    # DispatcherInfo(
-    #     F.perspective,
-    #     kernels={
-    #         features.Image: F.perspective_image_tensor,
-    #         features.BoundingBox: F.perspective_bounding_box,
-    #         features.Mask: F.perspective_mask,
-    #     },
-    # ),
+    DispatcherInfo(
+        F.perspective,
+        kernels={
+            features.Image: F.perspective_image_tensor,
+            features.BoundingBox: F.perspective_bounding_box,
+            features.Mask: F.perspective_mask,
+        },
+    ),
     DispatcherInfo(
         F.center_crop,
         kernels={
...
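
The removed FIXME documents the original failure: the scripted dispatcher declared perspective(Tensor inpt, int[][] startpoints, int[][] endpoints, ...), while kernels such as perspective_image_tensor expect the eight projective coefficients, so the test harness could not call both uniformly. With the signatures aligned, the DispatcherInfo entry can be re-enabled. A minimal sketch of the now-uniform call, assuming the prototype API above (the identity coefficients and input shape are made up for illustration):

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    # Eight coefficients (a, b, c, d, e, f, g, h) of the projective map in the
    # PIL convention; [1, 0, 0, 0, 1, 0, 0, 0] is the identity transform.
    coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]

    img = features.Image(torch.rand(3, 32, 32))
    out = F.perspective(img, coeffs)  # same positional argument the kernels take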
@@ -894,21 +894,8 @@ class TestRandomPerspective:
         params = transform._get_params(image)

         h, w = image.image_size
-        assert len(params["startpoints"]) == 4
-        for x, y in params["startpoints"]:
-            assert x in (0, w - 1)
-            assert y in (0, h - 1)
-
-        assert len(params["endpoints"]) == 4
-        for (x, y), name in zip(params["endpoints"], ["tl", "tr", "br", "bl"]):
-            if "t" in name:
-                assert 0 <= y <= int(dscale * h // 2), (x, y, name)
-            if "b" in name:
-                assert h - int(dscale * h // 2) - 1 <= y <= h, (x, y, name)
-            if "l" in name:
-                assert 0 <= x <= int(dscale * w // 2), (x, y, name)
-            if "r" in name:
-                assert w - int(dscale * w // 2) - 1 <= x <= w, (x, y, name)
+        assert "perspective_coeffs" in params
+        assert len(params["perspective_coeffs"]) == 8

     @pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
     def test__transform(self, distortion_scale, mocker):
...
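
The test no longer inspects corner points; it only checks that _get_params exposes the eight coefficients. The count follows from the math: four point correspondences give eight scalar equations, which determine the eight free parameters of a projective transform (the 3x3 matrix is defined only up to scale). A short runnable sketch, with made-up corner values:

    from torchvision.transforms.functional import _get_perspective_coeffs

    startpoints = [[0, 0], [31, 0], [31, 31], [0, 31]]  # image corners
    endpoints = [[2, 1], [30, 3], [28, 30], [1, 29]]    # hypothetical distorted corners
    coeffs = _get_perspective_coeffs(startpoints, endpoints)
    assert len(coeffs) == 8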
@@ -1232,7 +1232,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         np.max(transformed_points[:, 1]),
     ]
     out_bbox = features.BoundingBox(
-        out_bbox,
+        np.array(out_bbox),
         format=features.BoundingBoxFormat.XYXY,
         image_size=bbox.image_size,
         dtype=bbox.dtype,
...
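
The np.array(out_bbox) change is a small construction detail: at this point out_bbox is a Python list of NumPy scalars produced by np.min/np.max, and converting it to a single homogeneous ndarray before building the BoundingBox presumably makes the tensor conversion cleaner than handing over a list of scalar objects. A rough illustration of the data involved (values made up):

    import numpy as np
    import torch

    out_bbox = [np.float64(1.0), np.float64(2.0), np.float64(10.0), np.float64(12.0)]
    t = torch.as_tensor(np.array(out_bbox))  # one homogeneous float64 array in, one tensor out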
@@ -9,6 +9,7 @@ import torch
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
+from torchvision.transforms.functional import _get_perspective_coeffs
 from typing_extensions import Literal
@@ -556,7 +557,8 @@ class RandomPerspective(_RandomApplyTransform):
         ]
         startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
         endpoints = [topleft, topright, botright, botleft]
-        return dict(startpoints=startpoints, endpoints=endpoints)
+        perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
+        return dict(perspective_coeffs=perspective_coeffs)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
...
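
Moving the _get_perspective_coeffs call into _get_params means the point-to-coefficient conversion happens once per sample, and every component of that sample is warped with the same coefficients. A sketch of that reuse, assuming the prototype feature classes from the dispatch table above (values made up):

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    params = {"perspective_coeffs": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]}  # identity
    image = features.Image(torch.rand(3, 32, 32))
    mask = features.Mask(torch.zeros(32, 32, dtype=torch.uint8))

    # One coefficient set keeps image and mask geometrically aligned.
    out_image = F.perspective(image, params["perspective_coeffs"])
    out_mask = F.perspective(mask, params["perspective_coeffs"])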
@@ -9,7 +9,6 @@ from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 from torchvision.transforms.functional import (
     _compute_resized_output_size,
     _get_inverse_affine_matrix,
-    _get_perspective_coeffs,
     InterpolationMode,
     pil_modes_mapping,
     pil_to_tensor,
@@ -876,13 +875,10 @@ def perspective_mask(

 def perspective(
     inpt: features.DType,
-    startpoints: List[List[int]],
-    endpoints: List[List[int]],
+    perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: Optional[Union[int, float, List[float]]] = None,
 ) -> features.DType:
-    perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
-
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, features._Feature):
...
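
With the conversion removed from the dispatcher body, F.perspective becomes a thin pass-through whose declaration matches the kernels, which is exactly what the scripted-dispatcher test in the first hunk needed. A sketch of both entry points, assuming identity coefficients (made up) and that scripting the dispatcher works as the removed FIXME implies:

    import torch
    from torchvision.prototype.transforms import functional as F

    coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]  # identity
    plain = torch.rand(3, 16, 16)
    out = F.perspective(plain, coeffs)  # plain tensors go straight to perspective_image_tensor

    # Scripting is what surfaced the original RuntimeError quoted in the removed FIXME.
    scripted = torch.jit.script(F.perspective)
    out_scripted = scripted(plain, coeffs)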