"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "174738c33e511f11ee7810772e9f80f6c734993a"
Unverified Commit 96ec0e1d authored by eellison's avatar eellison Committed by GitHub
Browse files

Add expected result tests (#1377)

* add expected result tests

* fix wrong assertion

* start with only detection models

* remove unneeded rng setting

* fix test

* add tuple support

* update test

* syntax error

* treat .pkl files as binary data, see : https://git-scm.com/book/en/v2/Customizing-Git-Git-Attributes#_binary_files

* fix test

* fix elif

* Map tensor results and enforce maximum pickle size

* unrelated change

* larger rtol

* pass rtol atol around

* last commit i swear...

* respond to comments

* fix flake

* fix py2 flake
parent a8561215
*.pkl binary
...@@ -2,6 +2,12 @@ import os ...@@ -2,6 +2,12 @@ import os
import shutil import shutil
import tempfile import tempfile
import contextlib import contextlib
import unittest
import argparse
import sys
import torch
import errno
import __main__
@contextlib.contextmanager @contextlib.contextmanager
...@@ -14,3 +20,134 @@ def get_tmp_dir(src=None, **kwargs): ...@@ -14,3 +20,134 @@ def get_tmp_dir(src=None, **kwargs):
yield tmp_dir yield tmp_dir
finally: finally:
shutil.rmtree(tmp_dir) shutil.rmtree(tmp_dir)
ACCEPT = os.getenv('EXPECTTEST_ACCEPT')
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
if not ACCEPT:
ACCEPT = args.accept
for i, arg in enumerate(sys.argv):
if arg == '--accept':
del sys.argv[i]
break
class MapNestedTensorObjectImpl(object):
def __init__(self, tensor_map_fn):
self.tensor_map_fn = tensor_map_fn
def __call__(self, object):
if isinstance(object, torch.Tensor):
return self.tensor_map_fn(object)
elif isinstance(object, dict):
mapped_dict = {}
for key, value in object.items():
mapped_dict[self(key)] = self(value)
return mapped_dict
elif isinstance(object, (list, tuple)):
mapped_iter = []
for iter in object:
mapped_iter.append(self(iter))
return mapped_iter if not isinstance(object, tuple) else tuple(mapped_iter)
else:
return object
def map_nested_tensor_object(object, tensor_map_fn):
impl = MapNestedTensorObjectImpl(tensor_map_fn)
return impl(object)
# adapted from TestCase in torch/test/common_utils to accept non-string
# inputs and set maximum binary size
class TestCase(unittest.TestCase):
def assertExpected(self, output, subname=None, rtol=None, atol=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using --accept.
If you call this multiple times in a single function, you must
give a unique subname each time.
"""
def remove_prefix(text, prefix):
if text.startswith(prefix):
return text[len(prefix):]
return text
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
# lives, NOT where test/common_utils.py lives.
module_id = self.__class__.__module__
munged_id = remove_prefix(self.id(), module_id + ".")
test_file = os.path.realpath(sys.modules[module_id].__file__)
expected_file = os.path.join(os.path.dirname(test_file),
"expect",
munged_id)
subname_output = ""
if subname:
expected_file += "_" + subname
subname_output = " ({})".format(subname)
expected_file += "_expect.pkl"
expected = None
def accept_output(update_type):
print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, output))
torch.save(output, expected_file)
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file)
self.assertTrue(binary_size <= MAX_PICKLE_SIZE)
try:
expected = torch.load(expected_file)
except IOError as e:
if e.errno != errno.ENOENT:
raise
elif ACCEPT:
return accept_output("output")
else:
raise RuntimeError(
("I got this output for {}{}:\n\n{}\n\n"
"No expect file exists; to accept the current output, run:\n"
"python {} {} --accept").format(munged_id, subname_output, output, __main__.__file__, munged_id))
if ACCEPT:
equal = False
try:
equal = self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol)
except Exception:
equal = False
if not equal:
return accept_output("updated output")
else:
self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol)
def assertNestedTensorObjectsEqual(self, a, b, rtol=None, atol=None):
self.assertEqual(type(a), type(b))
if isinstance(a, torch.Tensor):
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
elif isinstance(a, dict):
self.assertEqual(len(a), len(b))
for key, value in a.items():
self.assertTrue(key in b, "key: " + str(key))
self.assertNestedTensorObjectsEqual(value, b[key], rtol=rtol, atol=atol)
elif isinstance(a, (list, tuple)):
self.assertEqual(len(a), len(b))
for val1, val2 in zip(a, b):
self.assertNestedTensorObjectsEqual(val1, val2, rtol=rtol, atol=atol)
else:
self.assertEqual(a, b)
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
from common_utils import TestCase, map_nested_tensor_object
from collections import OrderedDict from collections import OrderedDict
from itertools import product from itertools import product
import torch import torch
import numpy as np
from torchvision import models from torchvision import models
import unittest import unittest
import traceback import traceback
import random
def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_available_classification_models(): def get_available_classification_models():
...@@ -42,7 +51,7 @@ torchub_models = [ ...@@ -42,7 +51,7 @@ torchub_models = [
] ]
class Tester(unittest.TestCase): class ModelTester(TestCase):
def check_script(self, model, name): def check_script(self, model, name):
if name not in torchub_models: if name not in torchub_models:
return return
...@@ -78,6 +87,7 @@ class Tester(unittest.TestCase): ...@@ -78,6 +87,7 @@ class Tester(unittest.TestCase):
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
def _test_detection_model(self, name): def _test_detection_model(self, name):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
self.check_script(model, name) self.check_script(model, name)
model.eval() model.eval()
...@@ -87,6 +97,33 @@ class Tester(unittest.TestCase): ...@@ -87,6 +97,33 @@ class Tester(unittest.TestCase):
out = model(model_input) out = model(model_input)
self.assertIs(model_input[0], x) self.assertIs(model_input[0], x)
self.assertEqual(len(out), 1) self.assertEqual(len(out), 1)
def subsample_tensor(tensor):
num_elems = tensor.numel()
num_samples = 20
if num_elems <= num_samples:
return tensor
flat_tensor = tensor.flatten()
ith_index = num_elems // num_samples
return flat_tensor[ith_index - 1::ith_index]
def compute_mean_std(tensor):
# can't compute mean of integral tensor
tensor = tensor.to(torch.double)
mean = torch.mean(tensor)
std = torch.std(tensor)
return {"mean": mean, "std": std}
# maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
# compare results with mean and std
if name == "maskrcnn_resnet50_fpn":
test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
# mean values are small, use large rtol
self.assertExpected(test_value, rtol=.01, atol=.01)
else:
self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor))
self.assertTrue("boxes" in out[0]) self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0]) self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0]) self.assertTrue("labels" in out[0])
...@@ -173,7 +210,7 @@ for model_name in get_available_classification_models(): ...@@ -173,7 +210,7 @@ for model_name in get_available_classification_models():
input_shape = (1, 3, 299, 299) input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape) self._test_classification_model(model_name, input_shape)
setattr(Tester, "test_" + model_name, do_test) setattr(ModelTester, "test_" + model_name, do_test)
for model_name in get_available_segmentation_models(): for model_name in get_available_segmentation_models():
...@@ -182,7 +219,7 @@ for model_name in get_available_segmentation_models(): ...@@ -182,7 +219,7 @@ for model_name in get_available_segmentation_models():
def do_test(self, model_name=model_name): def do_test(self, model_name=model_name):
self._test_segmentation_model(model_name) self._test_segmentation_model(model_name)
setattr(Tester, "test_" + model_name, do_test) setattr(ModelTester, "test_" + model_name, do_test)
for model_name in get_available_detection_models(): for model_name in get_available_detection_models():
...@@ -191,7 +228,7 @@ for model_name in get_available_detection_models(): ...@@ -191,7 +228,7 @@ for model_name in get_available_detection_models():
def do_test(self, model_name=model_name): def do_test(self, model_name=model_name):
self._test_detection_model(model_name) self._test_detection_model(model_name)
setattr(Tester, "test_" + model_name, do_test) setattr(ModelTester, "test_" + model_name, do_test)
for model_name in get_available_video_models(): for model_name in get_available_video_models():
...@@ -199,7 +236,7 @@ for model_name in get_available_video_models(): ...@@ -199,7 +236,7 @@ for model_name in get_available_video_models():
def do_test(self, model_name=model_name): def do_test(self, model_name=model_name):
self._test_video_model(model_name) self._test_video_model(model_name)
setattr(Tester, "test_" + model_name, do_test) setattr(ModelTester, "test_" + model_name, do_test)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment