"...text-generation-inference.git" did not exist on "e14ae3b5e9e8948f4113a8260e699fc8d46ededd"
Unverified commit 96ec0e1d, authored by eellison, committed by GitHub

Add expected result tests (#1377)

* add expected result tests

* fix wrong assertion

* start with only detection models

* remove unneeded rng setting

* fix test

* add tuple support

* update test

* syntax error

* treat .pkl files as binary data, see: https://git-scm.com/book/en/v2/Customizing-Git-Git-Attributes#_binary_files

* fix test

* fix elif

* Map tensor results and enforce maximum pickle size

* unrelated change

* larger rtol

* pass rtol atol around

* last commit i swear...

* respond to comments

* fix flake

* fix py2 flake
parent a8561215
The new .gitattributes entry (path not shown in this view):

*.pkl binary

Git's built-in "binary" macro attribute expands to "-diff -merge -text", so the pickled expect files are never diffed, merged textually, or EOL-converted.
@@ -2,6 +2,12 @@ import os
import shutil
import tempfile
import contextlib
import unittest
import argparse
import sys
import torch
import errno
import __main__
@contextlib.contextmanager
@@ -14,3 +20,134 @@ def get_tmp_dir(src=None, **kwargs):
        yield tmp_dir
    finally:
        shutil.rmtree(tmp_dir)
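For orientation, a minimal usage sketch of the helper above (the file name "checkpoint.pth" is invented): the temporary directory is removed even if the body raises.

# Hypothetical usage of get_tmp_dir.
with get_tmp_dir() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")
    torch.save({"step": 0}, path)
    assert os.path.exists(path)
# the finally clause has already removed the directory here
assert not os.path.exists(tmp_dir)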
ACCEPT = os.getenv('EXPECTTEST_ACCEPT')

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
if not ACCEPT:
    ACCEPT = args.accept
for i, arg in enumerate(sys.argv):
    if arg == '--accept':
        del sys.argv[i]
        break
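To make the flag handling above concrete, here is a small self-contained sketch (the test name in the example argv is invented): parse_known_args consumes --accept while leaving everything else for unittest, and the sys.argv scrub above ensures unittest.main() never sees the flag.

# Sketch only; 'MyTest.test_foo' is a hypothetical positional argument.
import argparse

demo = argparse.ArgumentParser(add_help=False)
demo.add_argument('--accept', action='store_true')
opts, rest = demo.parse_known_args(['--accept', 'MyTest.test_foo'])
assert opts.accept is True
assert rest == ['MyTest.test_foo']  # still available for unittest's own parser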
# Recursively applies tensor_map_fn to every torch.Tensor found in a nested
# structure of dicts, lists and tuples; all other leaves pass through as-is.
class MapNestedTensorObjectImpl(object):
    def __init__(self, tensor_map_fn):
        self.tensor_map_fn = tensor_map_fn

    def __call__(self, object):
        if isinstance(object, torch.Tensor):
            return self.tensor_map_fn(object)

        elif isinstance(object, dict):
            mapped_dict = {}
            for key, value in object.items():
                mapped_dict[self(key)] = self(value)
            return mapped_dict

        elif isinstance(object, (list, tuple)):
            mapped_iter = []
            for item in object:  # renamed from `iter` to avoid shadowing the builtin
                mapped_iter.append(self(item))
            return mapped_iter if not isinstance(object, tuple) else tuple(mapped_iter)

        else:
            return object
def map_nested_tensor_object(object, tensor_map_fn):
    impl = MapNestedTensorObjectImpl(tensor_map_fn)
    return impl(object)
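A quick illustration of the helper (the nested structure below is made up): tensors anywhere in a dict/list/tuple nest are transformed, everything else passes through unchanged.

# Hypothetical input resembling a detection output.
out = [{"boxes": torch.zeros(3, 4), "labels": [1, 2, 3]}]
moved = map_nested_tensor_object(out, tensor_map_fn=lambda t: t.double())
assert moved[0]["boxes"].dtype == torch.double  # tensor leaf was mapped
assert moved[0]["labels"] == [1, 2, 3]          # non-tensor leaves untouched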
# adapted from TestCase in torch/test/common_utils to accept non-string
# inputs and set maximum binary size
class TestCase(unittest.TestCase):
    def assertExpected(self, output, subname=None, rtol=None, atol=None):
        r"""
        Test that a python value matches the recorded contents of a file
        derived from the name of this test and subname. The value must be
        picklable with `torch.save`. The file is placed in the 'expect'
        directory in the same directory as the test script. You can
        automatically update the recorded test output using --accept.

        If you call this multiple times in a single function, you must
        give a unique subname each time.
        """
        def remove_prefix(text, prefix):
            if text.startswith(prefix):
                return text[len(prefix):]
            return text

        # NB: we take __file__ from the module that defined the test
        # class, so we place the expect directory where the test script
        # lives, NOT where test/common_utils.py lives.
        module_id = self.__class__.__module__
        munged_id = remove_prefix(self.id(), module_id + ".")
        test_file = os.path.realpath(sys.modules[module_id].__file__)
        expected_file = os.path.join(os.path.dirname(test_file),
                                     "expect",
                                     munged_id)

        subname_output = ""
        if subname:
            expected_file += "_" + subname
            subname_output = " ({})".format(subname)
        expected_file += "_expect.pkl"
        expected = None

        def accept_output(update_type):
            print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, output))
            torch.save(output, expected_file)
            MAX_PICKLE_SIZE = 50 * 1000  # 50 KB
            binary_size = os.path.getsize(expected_file)
            self.assertTrue(binary_size <= MAX_PICKLE_SIZE)

        try:
            expected = torch.load(expected_file)
        except IOError as e:
            if e.errno != errno.ENOENT:
                raise
            elif ACCEPT:
                return accept_output("output")
            else:
                raise RuntimeError(
                    ("I got this output for {}{}:\n\n{}\n\n"
                     "No expect file exists; to accept the current output, run:\n"
                     "python {} {} --accept").format(munged_id, subname_output, output,
                                                     __main__.__file__, munged_id))

        if ACCEPT:
            equal = False
            try:
                # assertNestedTensorObjectsEqual raises on mismatch and returns
                # None, so record success explicitly rather than keeping its
                # return value (which would always be falsy).
                self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol)
                equal = True
            except Exception:
                equal = False
            if not equal:
                return accept_output("updated output")
        else:
            self.assertNestedTensorObjectsEqual(output, expected, rtol=rtol, atol=atol)
    def assertNestedTensorObjectsEqual(self, a, b, rtol=None, atol=None):
        self.assertEqual(type(a), type(b))
        if isinstance(a, torch.Tensor):
            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)

        elif isinstance(a, dict):
            self.assertEqual(len(a), len(b))
            for key, value in a.items():
                self.assertTrue(key in b, "key: " + str(key))
                self.assertNestedTensorObjectsEqual(value, b[key], rtol=rtol, atol=atol)

        elif isinstance(a, (list, tuple)):
            self.assertEqual(len(a), len(b))
            for val1, val2 in zip(a, b):
                self.assertNestedTensorObjectsEqual(val1, val2, rtol=rtol, atol=atol)

        else:
            self.assertEqual(a, b)
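Putting the two methods together, a hypothetical test built on this class might look like the following (class name and tensor values invented). The first run with --accept or EXPECTTEST_ACCEPT=1 writes expect/ExampleTester.test_toy_expect.pkl; subsequent runs load it and compare with assertNestedTensorObjectsEqual.

class ExampleTester(TestCase):
    def test_toy(self):
        out = {"scores": torch.tensor([0.25, 0.75]), "labels": [0, 1]}
        self.assertExpected(out, rtol=1e-3, atol=1e-4)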
Three files are suppressed from this diff view by the .gitattributes entry above (the new binary .pkl expect files).
from common_utils import TestCase, map_nested_tensor_object
from collections import OrderedDict
from itertools import product
import torch
import numpy as np
from torchvision import models
import unittest
import traceback
import random


def set_rng_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

def get_available_classification_models():
@@ -42,7 +51,7 @@ torchub_models = [
]
-class Tester(unittest.TestCase):
+class ModelTester(TestCase):
    def check_script(self, model, name):
        if name not in torchub_models:
            return
@@ -78,6 +87,7 @@ class Tester(unittest.TestCase):
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
def _test_detection_model(self, name):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
self.check_script(model, name)
model.eval()
@@ -87,6 +97,33 @@ class Tester(unittest.TestCase):
        out = model(model_input)
        self.assertIs(model_input[0], x)
        self.assertEqual(len(out), 1)

        def subsample_tensor(tensor):
            num_elems = tensor.numel()
            num_samples = 20
            if num_elems <= num_samples:
                return tensor
            flat_tensor = tensor.flatten()
            ith_index = num_elems // num_samples
            return flat_tensor[ith_index - 1::ith_index]

        def compute_mean_std(tensor):
            # can't compute the mean of an integral tensor, so cast first
            tensor = tensor.to(torch.double)
            mean = torch.mean(tensor)
            std = torch.std(tensor)
            return {"mean": mean, "std": std}

        # maskrcnn_resnet50_fpn is numerically unstable across platforms,
        # so for now compare results with mean and std
        if name == "maskrcnn_resnet50_fpn":
            test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
            # mean values are small, use a large rtol
            self.assertExpected(test_value, rtol=.01, atol=.01)
        else:
            self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor))

        self.assertTrue("boxes" in out[0])
        self.assertTrue("scores" in out[0])
        self.assertTrue("labels" in out[0])
@@ -173,7 +210,7 @@ for model_name in get_available_classification_models():
        input_shape = (1, 3, 299, 299)
        self._test_classification_model(model_name, input_shape)

-    setattr(Tester, "test_" + model_name, do_test)
+    setattr(ModelTester, "test_" + model_name, do_test)
for model_name in get_available_segmentation_models():
@@ -182,7 +219,7 @@ for model_name in get_available_segmentation_models():
    def do_test(self, model_name=model_name):
        self._test_segmentation_model(model_name)

-    setattr(Tester, "test_" + model_name, do_test)
+    setattr(ModelTester, "test_" + model_name, do_test)
for model_name in get_available_detection_models():
@@ -191,7 +228,7 @@ for model_name in get_available_detection_models():
    def do_test(self, model_name=model_name):
        self._test_detection_model(model_name)

-    setattr(Tester, "test_" + model_name, do_test)
+    setattr(ModelTester, "test_" + model_name, do_test)
for model_name in get_available_video_models():
@@ -199,7 +236,7 @@ for model_name in get_available_video_models():
    def do_test(self, model_name=model_name):
        self._test_video_model(model_name)

-    setattr(Tester, "test_" + model_name, do_test)
+    setattr(ModelTester, "test_" + model_name, do_test)
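The generation loops above all rely on the same trick, sketched here with an invented class and model list: model_name is bound as a default argument, so each generated function captures the value from its own iteration rather than the loop variable's final value.

# Sketch of the dynamic-test pattern; Demo and the names list are hypothetical.
class Demo(object):
    pass

for name in ["model_a", "model_b"]:
    def do_test(self, model_name=name):  # default arg freezes this iteration's value
        return model_name
    setattr(Demo, "test_" + name, do_test)

assert Demo().test_model_a() == "model_a"
assert Demo().test_model_b() == "model_b"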
if __name__ == '__main__':
    unittest.main()