"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2c1ed50fc57f154768364e4506d7bab9daebf83d"
Unverified Commit 88b6b93d authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Extend `RandomShortestSize` to support Video specific flavour of the augmentation (#6770)

* Extend RandomShortestSize to support Video specific flavour of the augmentation

* Adding a test.

* Apply changes from code review
parent e3238e5a
...@@ -1379,10 +1379,9 @@ class TestScaleJitter: ...@@ -1379,10 +1379,9 @@ class TestScaleJitter:
class TestRandomShortestSize: class TestRandomShortestSize:
def test__get_params(self, mocker): @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
def test__get_params(self, min_size, max_size, mocker):
spatial_size = (3, 10) spatial_size = (3, 10)
min_size = [5, 9]
max_size = 20
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size) transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
...@@ -1395,10 +1394,9 @@ class TestRandomShortestSize: ...@@ -1395,10 +1394,9 @@ class TestRandomShortestSize:
assert isinstance(size, tuple) and len(size) == 2 assert isinstance(size, tuple) and len(size) == 2
longer = max(size) longer = max(size)
assert longer <= max_size
shorter = min(size) shorter = min(size)
if longer == max_size: if max_size is not None:
assert longer <= max_size
assert shorter <= max_size assert shorter <= max_size
else: else:
assert shorter in min_size assert shorter in min_size
......
...@@ -730,7 +730,7 @@ class RandomShortestSize(Transform): ...@@ -730,7 +730,7 @@ class RandomShortestSize(Transform):
def __init__( def __init__(
self, self,
min_size: Union[List[int], Tuple[int], int], min_size: Union[List[int], Tuple[int], int],
max_size: int, max_size: Optional[int] = None,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
): ):
...@@ -744,7 +744,9 @@ class RandomShortestSize(Transform): ...@@ -744,7 +744,9 @@ class RandomShortestSize(Transform):
orig_height, orig_width = query_spatial_size(flat_inputs) orig_height, orig_width = query_spatial_size(flat_inputs)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width)) r = min_size / min(orig_height, orig_width)
if self.max_size is not None:
r = min(r, self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r) new_width = int(orig_width * r)
new_height = int(orig_height * r) new_height = int(orig_height * r)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment