Unverified Commit 7daa90ac authored by Nicolas Hug, committed by GitHub

Take assertExpected and check_jit_scriptable out of the TestCase class (#3947)

parent 4c0fdc61
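
For context, the two helpers become plain module-level functions in test_models.py, so call sites change roughly like this (an illustrative sketch; the calls themselves are taken from the diff below):

# before: helpers were methods on the shared TestCase
self.assertExpected(out.cpu(), name, prec=0.1)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

# after: module-level helpers local to test_models.py
_assert_expected(out.cpu(), name, prec=0.1)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))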
@@ -5,9 +5,7 @@ import contextlib
import unittest
import argparse
import sys
import io
import torch
import warnings
import __main__
import random
import inspect
@@ -15,7 +13,6 @@ import inspect
from numbers import Number
from torch._six import string_classes
from collections import OrderedDict
from _utils_internal import get_relative_path
import numpy as np
from PIL import Image
@@ -49,10 +46,6 @@ def set_rng_seed(seed):
np.random.seed(seed)
ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1'
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
class MapNestedTensorObjectImpl(object):
def __init__(self, tensor_map_fn):
self.tensor_map_fn = tensor_map_fn
@@ -95,55 +88,6 @@ def is_iterable(obj):
class TestCase(unittest.TestCase):
precision = 1e-5
def _get_expected_file(self, name=None):
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
# lives, NOT where test/common_utils.py lives.
module_id = self.__class__.__module__
# Determine expected file based on environment
expected_file_base = get_relative_path(
os.path.realpath(sys.modules[module_id].__file__),
"expect")
# Note: for legacy reasons, the reference file names all start with "ModelTester.test_"
# We hardcode it here to avoid having to re-generate the reference files
expected_file = os.path.join(expected_file_base, 'ModelTester.test_' + name)
expected_file += "_expect.pkl"
if not ACCEPT and not os.path.exists(expected_file):
raise RuntimeError(
f"No expect file exists for {os.path.basename(expected_file)} in {expected_file}; "
"to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
"env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
)
return expected_file
def assertExpected(self, output, name, prec=None):
r"""
Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
picklable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using an EXPECTTEST_ACCEPT=1 env variable.
"""
expected_file = self._get_expected_file(name)
if ACCEPT:
filename = os.path.basename(expected_file)
print("Accepting updated output for {}:\n\n{}".format(filename, output))
torch.save(output, expected_file)
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file)
if binary_size > MAX_PICKLE_SIZE:
raise RuntimeError("The output for {}, is larger than 50kb".format(filename))
else:
expected = torch.load(expected_file)
rtol = atol = prec or self.precision
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
"""
This is copied from pytorch/test/common_utils.py's TestCase.assertEqual
@@ -261,58 +205,6 @@ class TestCase(unittest.TestCase):
else:
super(TestCase, self).assertEqual(x, y, message)
def check_jit_scriptable(self, nn_module, args, unwrapper=None, skip=False):
"""
Check that an nn.Module's results in TorchScript match eager mode and that it
can be exported
"""
if not TEST_WITH_SLOW or skip:
# TorchScript checks only run with PYTORCH_TEST_WITH_SLOW=1 (and when not explicitly skipped)
msg = "The check_jit_scriptable test for {} was skipped. " \
"This test checks if the module's results in TorchScript " \
"match eager and that it can be exported. To run these " \
"tests make sure you set the environment variable " \
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \
"manually skipped.".format(nn_module.__class__.__name__)
warnings.warn(msg, RuntimeWarning)
return None
sm = torch.jit.script(nn_module)
with freeze_rng_state():
eager_out = nn_module(*args)
with freeze_rng_state():
script_out = sm(*args)
if unwrapper:
script_out = unwrapper(script_out)
self.assertEqual(eager_out, script_out, prec=1e-4)
self.assertExportImportModule(sm, args)
return sm
def getExportImportCopy(self, m):
"""
Save and load a TorchScript model
"""
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
imported = torch.jit.load(buffer)
return imported
def assertExportImportModule(self, m, args):
"""
Check that the results of a model are the same after saving and loading
"""
m_import = self.getExportImportCopy(m)
with freeze_rng_state():
results = m(*args)
with freeze_rng_state():
results_from_imported = m_import(*args)
self.assertEqual(results, results_from_imported, prec=3e-4)
@contextlib.contextmanager
def freeze_rng_state():
......
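
For reference, the body of freeze_rng_state is elided above; it typically follows the standard save-and-restore pattern for the PyTorch RNG state, roughly along these lines (a sketch under that assumption, not taken from the diff):

@contextlib.contextmanager
def freeze_rng_state():
    # Save the CPU RNG state, run the wrapped block, then restore the state, so that
    # e.g. eager and scripted forward passes see identical random numbers.
    rng_state = torch.get_rng_state()
    yield
    torch.set_rng_state(rng_state)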
import os
import io
import sys
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed, IN_CIRCLE_CI
from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state, set_rng_seed
from _utils_internal import get_relative_path
from collections import OrderedDict
from itertools import product
import functools
@@ -13,6 +16,9 @@ import warnings
import pytest
ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1'
def get_available_classification_models():
# TODO add a registration mechanism to torchvision.models
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
@@ -33,6 +39,103 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
def _get_expected_file(name=None):
# Determine expected file based on environment
expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
# Note: for legacy reasons, the reference file names all start with "ModelTester.test_"
# We hardcode it here to avoid having to re-generate the reference files
expected_file = os.path.join(expected_file_base, 'ModelTester.test_' + name)
expected_file += "_expect.pkl"
if not ACCEPT and not os.path.exists(expected_file):
raise RuntimeError(
f"No expect file exists for {os.path.basename(expected_file)} in {expected_file}; "
"to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
"env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
)
return expected_file
def _assert_expected(output, name, prec):
"""Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
picklable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using an EXPECTTEST_ACCEPT=1 env variable.
"""
expected_file = _get_expected_file(name)
if ACCEPT:
filename = os.path.basename(expected_file)
print("Accepting updated output for {}:\n\n{}".format(filename, output))
torch.save(output, expected_file)
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file)
if binary_size > MAX_PICKLE_SIZE:
raise RuntimeError("The output for {}, is larger than 50kb".format(filename))
else:
expected = torch.load(expected_file)
rtol = atol = prec
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
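
As a usage sketch, a test computes a model output and hands it to the helper together with the expect-file name and a tolerance. The model, input shape, and name below are illustrative, mirroring the classification test further down rather than quoting it:

# illustrative only: mirrors the classification test further down
model = models.alexnet(num_classes=50).eval()
x = torch.rand(1, 3, 224, 224)
out = model(x)
_assert_expected(out.cpu(), 'alexnet', prec=0.1)  # compares against expect/ModelTester.test_alexnet_expect.pkl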
def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
"""Check that a nn.Module's results in TorchScript match eager and that it can be exported"""
def assert_export_import_module(m, args):
"""Check that the results of a model are the same after saving and loading"""
def get_export_import_copy(m):
"""Save and load a TorchScript model"""
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
imported = torch.jit.load(buffer)
return imported
m_import = get_export_import_copy(m)
with freeze_rng_state():
results = m(*args)
with freeze_rng_state():
results_from_imported = m_import(*args)
tol = 3e-4
try:
torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
except pytest.UsageError:
# custom check for the models that return named tuples:
# we compare field by field while ignoring None as assert_close can't handle None
for a, b in zip(results, results_from_imported):
if a is not None:
torch.testing.assert_close(a, b, atol=tol, rtol=tol)
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
if not TEST_WITH_SLOW or skip:
# TorchScript checks only run with PYTORCH_TEST_WITH_SLOW=1 (and when not explicitly skipped)
msg = "The check_jit_scriptable test for {} was skipped. " \
"This test checks if the module's results in TorchScript " \
"match eager and that it can be exported. To run these " \
"tests make sure you set the environment variable " \
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \
"manually skipped.".format(nn_module.__class__.__name__)
warnings.warn(msg, RuntimeWarning)
return None
sm = torch.jit.script(nn_module)
with freeze_rng_state():
eager_out = nn_module(*args)
with freeze_rng_state():
script_out = sm(*args)
if unwrapper:
script_out = unwrapper(script_out)
torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)
assert_export_import_module(sm, args)
# If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode
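
For illustration, an entry in the (elided) script_model_unwrapper mapping might look like this; the specific model and output layout are an assumption made for the example, not quoted from the diff:

# hypothetical unwrapper entry: a scripted detection model that returns a
# (losses, detections) tuple gets reduced to just the detections before comparison
script_model_unwrapper = {
    'fasterrcnn_resnet50_fpn': lambda out: out[1],
}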
@@ -132,16 +235,16 @@ class ModelTester(TestCase):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.assertExpected(out.cpu(), name, prec=0.1)
_assert_expected(out.cpu(), name, prec=0.1)
self.assertEqual(out.shape[-1], 50)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
out = model(x)
# See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics:
self.assertExpected(out.cpu(), name, prec=0.1)
_assert_expected(out.cpu(), name, prec=0.1)
self.assertEqual(out.shape[-1], 50)
def _test_segmentation_model(self, name, dev):
@@ -166,12 +269,12 @@ class ModelTester(TestCase):
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
# where we need to create a new expected result.
self.assertExpected(out.cpu(), name, prec=prec)
_assert_expected(out.cpu(), name, prec=prec)
except AssertionError:
# Unfortunately some segmentation models are flaky with autocast
# so instead of validating the probability scores, check that the class
# predictions match.
expected_file = self._get_expected_file(name)
expected_file = _get_expected_file(name)
expected = torch.load(expected_file)
torch.testing.assert_close(out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec)
return False # Partial validation performed
@@ -180,7 +283,7 @@ class ModelTester(TestCase):
full_validation = check_out(out)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
@@ -248,13 +351,13 @@ class ModelTester(TestCase):
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
# where we need to create a new expected result.
self.assertExpected(output, name, prec=prec)
_assert_expected(output, name, prec=prec)
except AssertionError:
# Unfortunately detection models are flaky due to the unstable sort
# in NMS. If matching across all outputs fails, use the same approach
# as in NMSTester.test_nms_cuda to see if this is caused by duplicate
# scores.
expected_file = self._get_expected_file(name)
expected_file = _get_expected_file(name)
expected = torch.load(expected_file)
torch.testing.assert_close(output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec,
check_device=False, check_dtype=False)
@@ -268,7 +371,7 @@ class ModelTester(TestCase):
return True # Full validation performed
full_validation = check_out(out)
self.check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))
_check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))
if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
@@ -318,7 +421,7 @@ class ModelTester(TestCase):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
self.assertEqual(out.shape[-1], 50)
if dev == torch.device("cuda"):
@@ -398,7 +501,7 @@ class ModelTester(TestCase):
model.AuxLogits = None
model = model.eval()
x = torch.rand(1, 3, 299, 299)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
def test_fasterrcnn_double(self):
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
@@ -427,7 +530,7 @@ class ModelTester(TestCase):
model.aux2 = None
model = model.eval()
x = torch.rand(1, 3, 224, 224)
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
@unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
def test_fasterrcnn_switch_devices(self):
......