Unverified Commit 77c8c91c authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Ported all transforms to the new API (#6305)

* [proto] Added few transforms tests, part 1 (#6262)

* Added supported/unsupported data checks in the tests for cutmix/mixup

* Added RandomRotation, RandomAffine transforms tests

* Added tests for RandomZoomOut, Pad

* Update test_prototype_transforms.py

* Added RandomCrop transform and tests (#6271)

* [proto] Added GaussianBlur transform and tests (#6273)

* Added GaussianBlur transform and tests

* Fixing code format

* Copied correctness test

* [proto] Added random color transforms and tests (#6275)

* Added random color transforms and tests

* Disable smoke test for RandomSolarize, RandomAdjustSharpness

* Added RandomPerspective and tests (#6284)

- replaced real image creation by mocks for other tests

* Added more functional tests (#6285)

* [proto] Added elastic transform and tests (#6295)

* WIP [proto] Added functional elastic transform with tests

* Added more functional tests

* WIP on elastic op

* Added elastic transform and tests

* Added tests

* Added tests for ElasticTransform

* Try to format code as in https://github.com/pytorch/vision/pull/5106



* Fixed bug in affine get_params test

* Implemented RandomErase on PIL input as fallback to tensors (#6309)

Added tests

* Added image_size computation for BoundingBox.rotate if expand (#6319)

* Added image_size computation for BoundingBox.rotate if expand

* Added tests

* Added erase_image_pil and eager/jit erase_image_tensor test (#6320)

* Updates according to the review
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0ed5d811
...@@ -634,7 +634,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] ...@@ -634,7 +634,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
cmax = torch.ceil((max_vals / tol).trunc_() * tol) cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol) cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin size = cmax - cmin
return int(size[0]), int(size[1]) return int(size[0]), int(size[1]) # w, h
def rotate( def rotate(
...@@ -932,6 +932,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool ...@@ -932,6 +932,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
return img return img
def _create_identity_grid(size: List[int]) -> Tensor:
hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
def elastic_transform( def elastic_transform(
img: Tensor, img: Tensor,
displacement: Tensor, displacement: Tensor,
...@@ -945,8 +951,6 @@ def elastic_transform( ...@@ -945,8 +951,6 @@ def elastic_transform(
size = list(img.shape[-2:]) size = list(img.shape[-2:])
displacement = displacement.to(img.device) displacement = displacement.to(img.device)
hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] identity_grid = _create_identity_grid(size)
grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
identity_grid = torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
grid = identity_grid.to(img.device) + displacement grid = identity_grid.to(img.device) + displacement
return _apply_grid_transform(img, grid, interpolation, fill) return _apply_grid_transform(img, grid, interpolation, fill)
...@@ -1855,7 +1855,7 @@ def _check_sequence_input(x, name, req_sizes): ...@@ -1855,7 +1855,7 @@ def _check_sequence_input(x, name, req_sizes):
if not isinstance(x, Sequence): if not isinstance(x, Sequence):
raise TypeError(f"{name} should be a sequence of length {msg}.") raise TypeError(f"{name} should be a sequence of length {msg}.")
if len(x) not in req_sizes: if len(x) not in req_sizes:
raise ValueError(f"{name} should be sequence of length {msg}.") raise ValueError(f"{name} should be a sequence of length {msg}.")
def _setup_angle(x, name, req_sizes=(2,)): def _setup_angle(x, name, req_sizes=(2,)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment