Unverified Commit 1a1fea34 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix prototype RandomErasing test (#6472)

parent 6279089a
...@@ -972,8 +972,9 @@ class TestRandomErasing: ...@@ -972,8 +972,9 @@ class TestRandomErasing:
assert 0 <= i <= image.image_size[0] - h assert 0 <= i <= image.image_size[0] - h
assert 0 <= j <= image.image_size[1] - w assert 0 <= j <= image.image_size[1] - w
def test__transform(self, mocker): @pytest.mark.parametrize("p", [0, 1])
transform = transforms.RandomErasing() def test__transform(self, mocker, p):
transform = transforms.RandomErasing(p=p)
transform._transformed_types = (mocker.MagicMock,) transform._transformed_types = (mocker.MagicMock,)
i_sentinel = mocker.MagicMock() i_sentinel = mocker.MagicMock()
...@@ -989,11 +990,15 @@ class TestRandomErasing: ...@@ -989,11 +990,15 @@ class TestRandomErasing:
inpt_sentinel = mocker.MagicMock() inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.prototype.transforms._augment.F.erase") mock = mocker.patch("torchvision.prototype.transforms._augment.F.erase")
transform(inpt_sentinel) output = transform(inpt_sentinel)
mock.assert_called_once_with( if p:
inpt_sentinel, i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel mock.assert_called_once_with(
) inpt_sentinel, i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel
)
else:
mock.assert_not_called()
assert output is inpt_sentinel
class TestTransform: class TestTransform:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment