You need to sign in or sign up before continuing.
Unverified Commit 754c954f authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Added scripted fn save test for random erase (#2767)

parent 5320f742
...@@ -741,10 +741,6 @@ class Tester(TransformsTester): ...@@ -741,10 +741,6 @@ class Tester(TransformsTester):
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0 batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
) )
def test_convert_image_dtype(self):
# TODO: add tests of CPU/CUDA on tensor and batch
pass
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester): class CUDATester(Tester):
......
...@@ -490,6 +490,9 @@ class Tester(TransformsTester): ...@@ -490,6 +490,9 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(fn, scripted_fn, tensor) self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
def test_convert_image_dtype(self): def test_convert_image_dtype(self):
tensor, _ = self._create_data(26, 34, device=self.device) tensor, _ = self._create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device) batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment