Unverified Commit f093d082 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

disable weight download and state dict loading for model tests (#4867)

* disable weight download and state dict loading for model tests

* fix indent

* debug

* nuclear option

* revert unrelated change

* cleanup

* add explanation

* typo
parent a23778c0
import contextlib
import functools import functools
import io import io
import operator import operator
import os import os
import pkgutil
import sys
import traceback import traceback
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
...@@ -14,7 +17,6 @@ from _utils_internal import get_relative_path ...@@ -14,7 +17,6 @@ from _utils_internal import get_relative_path
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
from torchvision import models from torchvision import models
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
...@@ -23,6 +25,51 @@ def get_models_from_module(module): ...@@ -23,6 +25,51 @@ def get_models_from_module(module):
return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
@pytest.fixture
def disable_weight_loading(mocker):
"""When testing models, the two slowest operations are the downloading of the weights to a file and loading them
into the model. Unless, you want to test against specific weights, these steps can be disabled without any
drawbacks.
Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
through all models in `torchvision.models` and will patch all occurrences of the function
`download_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
no-ops.
.. warning:
Loaded models are still executable as normal, but will always have random weights. Make sure to not use this
fixture if you want to compare the model output against reference values.
"""
starting_point = models
function_name = "load_state_dict_from_url"
method_name = "load_state_dict"
module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
for name in module_names:
module = sys.modules.get(name)
if not module:
continue
if function_name in module.__dict__:
targets.add(f"{module.__name__}.{function_name}")
targets.update(
{
f"{module.__name__}.{obj.__name__}.{method_name}"
for obj in module.__dict__.values()
if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
}
)
for target in targets:
# See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
with contextlib.suppress(AttributeError):
mocker.patch(target)
def _get_expected_file(name=None): def _get_expected_file(name=None):
# Determine expected file based on environment # Determine expected file based on environment
expected_file_base = get_relative_path(os.path.realpath(__file__), "expect") expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
...@@ -762,7 +809,7 @@ def test_quantized_classification_model(model_fn): ...@@ -762,7 +809,7 @@ def test_quantized_classification_model(model_fn):
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
def test_detection_model_trainable_backbone_layers(model_fn): def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
model_name = model_fn.__name__ model_name = model_fn.__name__
max_trainable = _model_tests_values[model_name]["max_trainable"] max_trainable = _model_tests_values[model_name]["max_trainable"]
n_trainable_params = [] n_trainable_params = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment