Unverified Commit 9acec20d authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

revert changes to assert_equal (#5586)

* revert changes to assert_equal

* add tolerances for batch vs single test
parent 090d8237
...@@ -137,7 +137,7 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu ...@@ -137,7 +137,7 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
return batch_tensor return batch_tensor
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=1e-6) assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None): def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
...@@ -195,7 +195,7 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs): ...@@ -195,7 +195,7 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
for i in range(len(batch_tensors)): for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...] img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs) transformed_img = fn(img_tensor, **fn_kwargs)
assert_equal(transformed_img, transformed_batch[i, ...]) torch.testing.assert_close(transformed_img, transformed_batch[i, ...], rtol=0, atol=1e-6)
if scripted_fn_atol >= 0: if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment