Unverified Commit 88b6b93d authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Extend `RandomShortestSize` to support Video specific flavour of the augmentation (#6770)

* Extend RandomShortestSize to support Video specific flavour of the augmentation

* Adding a test.

* Apply changes from code review
parent e3238e5a
......@@ -1379,10 +1379,9 @@ class TestScaleJitter:
class TestRandomShortestSize:
def test__get_params(self, mocker):
@pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
def test__get_params(self, min_size, max_size, mocker):
spatial_size = (3, 10)
min_size = [5, 9]
max_size = 20
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
......@@ -1395,10 +1394,9 @@ class TestRandomShortestSize:
assert isinstance(size, tuple) and len(size) == 2
longer = max(size)
assert longer <= max_size
shorter = min(size)
if longer == max_size:
if max_size is not None:
assert longer <= max_size
assert shorter <= max_size
else:
assert shorter in min_size
......
......@@ -730,7 +730,7 @@ class RandomShortestSize(Transform):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: int,
max_size: Optional[int] = None,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
):
......@@ -744,7 +744,9 @@ class RandomShortestSize(Transform):
orig_height, orig_width = query_spatial_size(flat_inputs)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
r = min_size / min(orig_height, orig_width)
if self.max_size is not None:
r = min(r, self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment