Unverified Commit 3b0b6c01 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix minor issues with model tests. (#5576)

parent 57c8de7f
...@@ -81,7 +81,7 @@ def _get_expected_file(name=None): ...@@ -81,7 +81,7 @@ def _get_expected_file(name=None):
# Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names # Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names
# We hardcode it here to avoid having to re-generate the reference files # We hardcode it here to avoid having to re-generate the reference files
expected_file = expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name) expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name)
expected_file += "_expect.pkl" expected_file += "_expect.pkl"
if not ACCEPT and not os.path.exists(expected_file): if not ACCEPT and not os.path.exists(expected_file):
...@@ -665,6 +665,7 @@ def test_detection_model(model_fn, dev): ...@@ -665,6 +665,7 @@ def test_detection_model(model_fn, dev):
assert len(out) == 1 assert len(out) == 1
def compact(tensor): def compact(tensor):
tensor = tensor.cpu()
size = tensor.size() size = tensor.size()
elements_per_sample = functools.reduce(operator.mul, size[1:], 1) elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
if elements_per_sample > 30: if elements_per_sample > 30:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment