Unverified commit d0dede0e, authored by Vasilis Vryniotis and committed by GitHub
Browse files

Speed up Model tests by 20% (#5574)

* Measuring execution times of models.

* Speed up models by avoiding re-estimation of eager output

* Fixing linter

* Reduce input size for big models

* Speed up jit check method.

* Add simple jitscript fallback check for flaky models.

* Restore pytest filtering

* Fixing linter
parent cddad9ca
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
...@@ -4,7 +4,6 @@ import operator ...@@ -4,7 +4,6 @@ import operator
import os import os
import pkgutil import pkgutil
import sys import sys
import traceback
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
...@@ -119,12 +118,9 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None): ...@@ -119,12 +118,9 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False) torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False, eager_out=None):
"""Check that a nn.Module's results in TorchScript match eager and that it can be exported""" """Check that a nn.Module's results in TorchScript match eager and that it can be exported"""
def assert_export_import_module(m, args):
"""Check that the results of a model are the same after saving and loading"""
def get_export_import_copy(m): def get_export_import_copy(m):
"""Save and load a TorchScript model""" """Save and load a TorchScript model"""
with TemporaryDirectory() as dir: with TemporaryDirectory() as dir:
...@@ -133,14 +129,6 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -133,14 +129,6 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
imported = torch.jit.load(path) imported = torch.jit.load(path)
return imported return imported
m_import = get_export_import_copy(m)
with torch.no_grad(), freeze_rng_state():
results = m(*args)
with torch.no_grad(), freeze_rng_state():
results_from_imported = m_import(*args)
tol = 3e-4
torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1" TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1"
if not TEST_WITH_SLOW or skip: if not TEST_WITH_SLOW or skip:
# TorchScript is not enabled, skip these tests # TorchScript is not enabled, skip these tests
...@@ -157,7 +145,9 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -157,7 +145,9 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
sm = torch.jit.script(nn_module) sm = torch.jit.script(nn_module)
if eager_out is None:
with torch.no_grad(), freeze_rng_state(): with torch.no_grad(), freeze_rng_state():
if unwrapper:
eager_out = nn_module(*args) eager_out = nn_module(*args)
with torch.no_grad(), freeze_rng_state(): with torch.no_grad(), freeze_rng_state():
...@@ -166,14 +156,22 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -166,14 +156,22 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
script_out = unwrapper(script_out) script_out = unwrapper(script_out)
torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4) torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)
assert_export_import_module(sm, args)
m_import = get_export_import_copy(sm)
with torch.no_grad(), freeze_rng_state():
imported_script_out = m_import(*args)
if unwrapper:
imported_script_out = unwrapper(imported_script_out)
torch.testing.assert_close(script_out, imported_script_out, atol=3e-4, rtol=3e-4)
def _check_fx_compatible(model, inputs): def _check_fx_compatible(model, inputs, eager_out=None):
model_fx = torch.fx.symbolic_trace(model) model_fx = torch.fx.symbolic_trace(model)
out = model(inputs) if eager_out is None:
out_fx = model_fx(inputs) eager_out = model(inputs)
torch.testing.assert_close(out, out_fx) fx_out = model_fx(inputs)
torch.testing.assert_close(eager_out, fx_out)
def _check_input_backprop(model, inputs): def _check_input_backprop(model, inputs):
...@@ -298,6 +296,24 @@ _model_params = { ...@@ -298,6 +296,24 @@ _model_params = {
"rpn_post_nms_top_n_test": 1000, "rpn_post_nms_top_n_test": 1000,
}, },
} }
# speeding up slow models:
slow_models = [
"convnext_base",
"convnext_large",
"resnext101_32x8d",
"wide_resnet101_2",
"efficientnet_b6",
"efficientnet_b7",
"efficientnet_v2_m",
"efficientnet_v2_l",
"regnet_y_16gf",
"regnet_y_32gf",
"regnet_y_128gf",
"regnet_x_16gf",
"regnet_x_32gf",
]
for m in slow_models:
_model_params[m] = {"input_shape": (1, 3, 64, 64)}
# The following contains configuration and expected values to be used in tests that are model specific # The following contains configuration and expected values to be used in tests that are model specific
...@@ -564,8 +580,8 @@ def test_classification_model(model_fn, dev): ...@@ -564,8 +580,8 @@ def test_classification_model(model_fn, dev):
out = model(x) out = model(x)
_assert_expected(out.cpu(), model_name, prec=0.1) _assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == num_classes assert out.shape[-1] == num_classes
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x) _check_fx_compatible(model, x, eager_out=out)
if dev == torch.device("cuda"): if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
...@@ -595,7 +611,7 @@ def test_segmentation_model(model_fn, dev): ...@@ -595,7 +611,7 @@ def test_segmentation_model(model_fn, dev):
model.eval().to(device=dev) model.eval().to(device=dev)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev) x = torch.rand(input_shape).to(device=dev)
out = model(x)["out"] out = model(x)
def check_out(out): def check_out(out):
prec = 0.01 prec = 0.01
...@@ -615,17 +631,17 @@ def test_segmentation_model(model_fn, dev): ...@@ -615,17 +631,17 @@ def test_segmentation_model(model_fn, dev):
return True # Full validation performed return True # Full validation performed
full_validation = check_out(out) full_validation = check_out(out["out"])
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x) _check_fx_compatible(model, x, eager_out=out)
if dev == torch.device("cuda"): if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
out = model(x)["out"] out = model(x)
# See autocast_flaky_numerics comment at top of file. # See autocast_flaky_numerics comment at top of file.
if model_name not in autocast_flaky_numerics: if model_name not in autocast_flaky_numerics:
full_validation &= check_out(out) full_validation &= check_out(out["out"])
if not full_validation: if not full_validation:
msg = ( msg = (
...@@ -716,7 +732,7 @@ def test_detection_model(model_fn, dev): ...@@ -716,7 +732,7 @@ def test_detection_model(model_fn, dev):
return True # Full validation performed return True # Full validation performed
full_validation = check_out(out) full_validation = check_out(out)
_check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None)) _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
if dev == torch.device("cuda"): if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
...@@ -780,8 +796,8 @@ def test_video_model(model_fn, dev): ...@@ -780,8 +796,8 @@ def test_video_model(model_fn, dev):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev) x = torch.rand(input_shape).to(device=dev)
out = model(x) out = model(x)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x) _check_fx_compatible(model, x, eager_out=out)
assert out.shape[-1] == 50 assert out.shape[-1] == 50
if dev == torch.device("cuda"): if dev == torch.device("cuda"):
...@@ -821,8 +837,13 @@ def test_quantized_classification_model(model_fn): ...@@ -821,8 +837,13 @@ def test_quantized_classification_model(model_fn):
if model_name not in quantized_flaky_models: if model_name not in quantized_flaky_models:
_assert_expected(out, model_name + "_quantized", prec=0.1) _assert_expected(out, model_name + "_quantized", prec=0.1)
assert out.shape[-1] == 5 assert out.shape[-1] == 5
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x) _check_fx_compatible(model, x, eager_out=out)
else:
try:
torch.jit.script(model)
except Exception as e:
raise AssertionError("model cannot be scripted.") from e
kwargs["quantize"] = False kwargs["quantize"] = False
for eval_mode in [True, False]: for eval_mode in [True, False]:
...@@ -843,12 +864,6 @@ def test_quantized_classification_model(model_fn): ...@@ -843,12 +864,6 @@ def test_quantized_classification_model(model_fn):
torch.ao.quantization.convert(model, inplace=True) torch.ao.quantization.convert(model, inplace=True)
try:
torch.jit.script(model)
except Exception as e:
tb = traceback.format_exc()
raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading): def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment