Unverified Commit 0f770ac9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Setting seeds for TestRoiPool backward. (#4763)

parent 5ffba76f
...@@ -46,9 +46,11 @@ class RoIOpTester(ABC): ...@@ -46,9 +46,11 @@ class RoIOpTester(ABC):
tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, device, contiguous): def test_backward(self, seed, device, contiguous):
torch.random.manual_seed(seed)
pool_size = 2 pool_size = 2
x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
if not contiguous: if not contiguous:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment