Unverified commit d0dede0e authored by Vasilis Vryniotis, committed by GitHub

Speed up Model tests by 20% (#5574)

* Measuring execution times of models.

* Speed up models by avoiding re-estimation of eager output

* Fixing linter

* Reduce input size for big models

* Speed up jit check method.

* Add simple jitscript fallback check for flaky models.

* Restore pytest filtering

* Fixing linter
parent cddad9ca
File suppressed by a .gitattributes entry or the file's encoding is unsupported. (this notice repeated for 14 files in the commit)
@@ -4,7 +4,6 @@ import operator
 import os
 import pkgutil
 import sys
-import traceback
 import warnings
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
@@ -119,27 +118,16 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
     torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
 
 
-def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
+def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False, eager_out=None):
     """Check that a nn.Module's results in TorchScript match eager and that it can be exported"""
 
-    def assert_export_import_module(m, args):
-        """Check that the results of a model are the same after saving and loading"""
-
-        def get_export_import_copy(m):
-            """Save and load a TorchScript model"""
-            with TemporaryDirectory() as dir:
-                path = os.path.join(dir, "script.pt")
-                m.save(path)
-                imported = torch.jit.load(path)
-            return imported
-
-        m_import = get_export_import_copy(m)
-        with torch.no_grad(), freeze_rng_state():
-            results = m(*args)
-        with torch.no_grad(), freeze_rng_state():
-            results_from_imported = m_import(*args)
-        tol = 3e-4
-        torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
+    def get_export_import_copy(m):
+        """Save and load a TorchScript model"""
+        with TemporaryDirectory() as dir:
+            path = os.path.join(dir, "script.pt")
+            m.save(path)
+            imported = torch.jit.load(path)
+        return imported
 
     TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1"
     if not TEST_WITH_SLOW or skip:
@@ -157,8 +145,10 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
     sm = torch.jit.script(nn_module)
 
-    with torch.no_grad(), freeze_rng_state():
-        eager_out = nn_module(*args)
+    if eager_out is None:
+        with torch.no_grad(), freeze_rng_state():
+            eager_out = nn_module(*args)
 
     with torch.no_grad(), freeze_rng_state():
         script_out = sm(*args)
         if unwrapper:
@@ -166,14 +156,22 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
             script_out = unwrapper(script_out)
 
     torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)
-    assert_export_import_module(sm, args)
+
+    m_import = get_export_import_copy(sm)
+    with torch.no_grad(), freeze_rng_state():
+        imported_script_out = m_import(*args)
+        if unwrapper:
+            imported_script_out = unwrapper(imported_script_out)
+
+    torch.testing.assert_close(script_out, imported_script_out, atol=3e-4, rtol=3e-4)
 
 
-def _check_fx_compatible(model, inputs):
+def _check_fx_compatible(model, inputs, eager_out=None):
     model_fx = torch.fx.symbolic_trace(model)
-    out = model(inputs)
-    out_fx = model_fx(inputs)
-    torch.testing.assert_close(out, out_fx)
+    if eager_out is None:
+        eager_out = model(inputs)
+    fx_out = model_fx(inputs)
+    torch.testing.assert_close(eager_out, fx_out)
 
 
 def _check_input_backprop(model, inputs):
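Note: the three hunks above restructure _check_jit_scriptable and _check_fx_compatible so that both helpers accept a precomputed eager_out and only re-run the model in eager mode when none is supplied; the nested assert_export_import_module helper is flattened into an inline export/import check. A minimal self-contained sketch of the pattern (the helper names check_jit/check_fx, the toy model, and the tolerances here are illustrative, not the test suite's real configuration):

import torch

def check_jit(model, x, eager_out=None):
    scripted = torch.jit.script(model)
    if eager_out is None:  # reuse the caller's forward pass when available
        with torch.no_grad():
            eager_out = model(x)
    with torch.no_grad():
        script_out = scripted(x)
    torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)

def check_fx(model, x, eager_out=None):
    traced = torch.fx.symbolic_trace(model)
    if eager_out is None:  # same trick for the FX-compatibility check
        eager_out = model(x)
    torch.testing.assert_close(eager_out, traced(x), atol=1e-4, rtol=1e-4)

model = torch.nn.Linear(4, 2).eval()
x = torch.rand(1, 4)
with torch.no_grad():
    out = model(x)  # one eager forward pass, shared by every check below
check_jit(model, x, eager_out=out)
check_fx(model, x, eager_out=out)

Each test previously triggered three eager forward passes (once in the test body, once per helper); passing eager_out=out cuts that to one, which is where most of the advertised 20% comes from.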
@@ -298,6 +296,24 @@ _model_params = {
         "rpn_post_nms_top_n_test": 1000,
     },
 }
+# speeding up slow models:
+slow_models = [
+    "convnext_base",
+    "convnext_large",
+    "resnext101_32x8d",
+    "wide_resnet101_2",
+    "efficientnet_b6",
+    "efficientnet_b7",
+    "efficientnet_v2_m",
+    "efficientnet_v2_l",
+    "regnet_y_16gf",
+    "regnet_y_32gf",
+    "regnet_y_128gf",
+    "regnet_x_16gf",
+    "regnet_x_32gf",
+]
+for m in slow_models:
+    _model_params[m] = {"input_shape": (1, 3, 64, 64)}
 
 # The following contains configuration and expected values to be used tests that are model specific
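Note: shrinking the input to 64x64 for these heavy architectures is the second main lever of the speedup; the tests only compare outputs against stored expected values, so a small image exercises the same code paths at a fraction of the forward-pass cost. A hedged sketch of how such a per-model override table is typically consumed (defaults and build_test_input below are illustrative names, not the verbatim plumbing of test_classification_model):

import torch

# _model_params is the override table shown in the hunk above
defaults = {"num_classes": 50, "input_shape": (1, 3, 224, 224)}

def build_test_input(model_name):
    # per-model overrides win over the shared defaults
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    input_shape = kwargs.pop("input_shape")
    return torch.rand(input_shape), kwargs  # RNG stays on CPU, as in the tests

x, kwargs = build_test_input("efficientnet_b7")
assert x.shape == (1, 3, 64, 64)  # picked up the slow-model override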
@@ -564,8 +580,8 @@ def test_classification_model(model_fn, dev):
     out = model(x)
     _assert_expected(out.cpu(), model_name, prec=0.1)
     assert out.shape[-1] == num_classes
-    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
-    _check_fx_compatible(model, x)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
+    _check_fx_compatible(model, x, eager_out=out)
 
     if dev == torch.device("cuda"):
         with torch.cuda.amp.autocast():
@@ -595,7 +611,7 @@ def test_segmentation_model(model_fn, dev):
     model.eval().to(device=dev)
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
-    out = model(x)["out"]
+    out = model(x)
 
     def check_out(out):
         prec = 0.01
@@ -615,17 +631,17 @@ def test_segmentation_model(model_fn, dev):
         return True  # Full validation performed
 
-    full_validation = check_out(out)
+    full_validation = check_out(out["out"])
 
-    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
-    _check_fx_compatible(model, x)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
+    _check_fx_compatible(model, x, eager_out=out)
 
     if dev == torch.device("cuda"):
         with torch.cuda.amp.autocast():
-            out = model(x)["out"]
+            out = model(x)
             # See autocast_flaky_numerics comment at top of file.
             if model_name not in autocast_flaky_numerics:
-                full_validation &= check_out(out)
+                full_validation &= check_out(out["out"])
 
     if not full_validation:
         msg = (
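Note: these two segmentation hunks keep out as the model's raw return value (segmentation models return an OrderedDict of tensors) so it can be handed to the helpers unmodified, which compare the whole dict against the scripted and traced outputs; check_out still validates only the "out" head, hence the out["out"] indexing moving to the call sites. Schematically (a sketch, not the verbatim test):

out = model(x)                                     # e.g. {"out": tensor, "aux": tensor}
full_validation = check_out(out["out"])            # numeric check on one head only
_check_jit_scriptable(model, (x,), eager_out=out)  # full dict compared to script output
_check_fx_compatible(model, x, eager_out=out)      # and to the FX-traced output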
@@ -716,7 +732,7 @@ def test_detection_model(model_fn, dev):
         return True  # Full validation performed
 
     full_validation = check_out(out)
-    _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None))
+    _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
 
     if dev == torch.device("cuda"):
         with torch.cuda.amp.autocast():
@@ -780,8 +796,8 @@ def test_video_model(model_fn, dev):
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
     out = model(x)
-    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
-    _check_fx_compatible(model, x)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
+    _check_fx_compatible(model, x, eager_out=out)
     assert out.shape[-1] == 50
 
     if dev == torch.device("cuda"):
@@ -821,8 +837,13 @@ def test_quantized_classification_model(model_fn):
     if model_name not in quantized_flaky_models:
         _assert_expected(out, model_name + "_quantized", prec=0.1)
         assert out.shape[-1] == 5
-        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
-        _check_fx_compatible(model, x)
+        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
+        _check_fx_compatible(model, x, eager_out=out)
+    else:
+        try:
+            torch.jit.script(model)
+        except Exception as e:
+            raise AssertionError("model cannot be scripted.") from e
 
     kwargs["quantize"] = False
     for eval_mode in [True, False]:
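Note: for models listed in quantized_flaky_models the exact output comparison is skipped, but the new else branch keeps a cheap smoke test: scripting must still succeed even when the numerics are too flaky to compare, replacing the heavier unconditional check (with its traceback formatting) removed in the next hunk. A standalone sketch of that fallback, using a toy model for illustration:

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()

try:
    torch.jit.script(model)  # scriptability check only; outputs are not compared
except Exception as e:
    raise AssertionError("model cannot be scripted.") from e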
@@ -843,12 +864,6 @@ def test_quantized_classification_model(model_fn):
             torch.ao.quantization.convert(model, inplace=True)
 
-    try:
-        torch.jit.script(model)
-    except Exception as e:
-        tb = traceback.format_exc()
-        raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e
-
 
 @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
 def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):