Unverified Commit 754c954f authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Added scripted fn save test for random erase (#2767)

parent 5320f742
...@@ -741,10 +741,6 @@ class Tester(TransformsTester): ...@@ -741,10 +741,6 @@ class Tester(TransformsTester):
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0 batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
) )
def test_convert_image_dtype(self):
# TODO: add tests of CPU/CUDA on tensor and batch
pass
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester): class CUDATester(Tester):
......
...@@ -490,6 +490,9 @@ class Tester(TransformsTester): ...@@ -490,6 +490,9 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(fn, scripted_fn, tensor) self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
def test_convert_image_dtype(self): def test_convert_image_dtype(self):
tensor, _ = self._create_data(26, 34, device=self.device) tensor, _ = self._create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device) batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment