Unverified Commit 5737ed27 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Move output to CPU before assert_expected (#6497)

* Move output to CPU before assert_expected

* Fixing `test_detection_model`

* check_device=False

* linter
parent 7cc2c95a
...@@ -113,7 +113,7 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None): ...@@ -113,7 +113,7 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
expected = torch.load(expected_file) expected = torch.load(expected_file)
rtol = rtol or prec # keeping prec param for legacy reason, but could be removed ideally rtol = rtol or prec # keeping prec param for legacy reason, but could be removed ideally
atol = atol or prec atol = atol or prec
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False) torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, check_device=False)
def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None): def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):
...@@ -671,7 +671,9 @@ def test_segmentation_model(model_fn, dev): ...@@ -671,7 +671,9 @@ def test_segmentation_model(model_fn, dev):
# predictions match. # predictions match.
expected_file = _get_expected_file(model_name) expected_file = _get_expected_file(model_name)
expected = torch.load(expected_file) expected = torch.load(expected_file)
torch.testing.assert_close(out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec) torch.testing.assert_close(
out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec, check_device=False
)
return False # Partial validation performed return False # Partial validation performed
return True # Full validation performed return True # Full validation performed
...@@ -893,7 +895,7 @@ def test_quantized_classification_model(model_fn): ...@@ -893,7 +895,7 @@ def test_quantized_classification_model(model_fn):
out = model(x) out = model(x)
if model_name not in quantized_flaky_models: if model_name not in quantized_flaky_models:
_assert_expected(out, model_name + "_quantized", prec=2e-2) _assert_expected(out.cpu(), model_name + "_quantized", prec=2e-2)
assert out.shape[-1] == 5 assert out.shape[-1] == 5
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x, eager_out=out) _check_fx_compatible(model, x, eager_out=out)
...@@ -960,7 +962,7 @@ def test_raft(model_fn, scripted): ...@@ -960,7 +962,7 @@ def test_raft(model_fn, scripted):
flow_pred = preds[-1] flow_pred = preds[-1]
# Tolerance is fairly high, but there are 2 * H * W outputs to check # Tolerance is fairly high, but there are 2 * H * W outputs to check
# The .pkl were generated on the AWS cluster, on the CI it looks like the results are slightly different # The .pkl were generated on the AWS cluster, on the CI it looks like the results are slightly different
_assert_expected(flow_pred, name=model_fn.__name__, atol=1e-2, rtol=1) _assert_expected(flow_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -35,4 +35,4 @@ def test_raft_stereo(model_fn, model_mode, dev): ...@@ -35,4 +35,4 @@ def test_raft_stereo(model_fn, model_mode, dev):
), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}" ), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"
# Test against expected file output # Test against expected file output
TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2) TM._assert_expected(depth_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment