Unverified Commit cbc36eb4 authored by MateuszGuzek's avatar MateuszGuzek Committed by GitHub
Browse files

Add filter parameters to `list_models()` (#7718)


Co-authored-by: default avatarMateusz Guzek <matguzek@meta.com>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 2d4484fb
...@@ -103,17 +103,18 @@ def test_weights_deserializable(name): ...@@ -103,17 +103,18 @@ def test_weights_deserializable(name):
assert pickle.loads(pickle.dumps(weights)) is weights assert pickle.loads(pickle.dumps(weights)) is weights
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow] "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
) )
def test_list_models(module): def test_list_models(module):
def get_models_from_module(module):
return [
v.__name__
for k, v in module.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
]
a = set(get_models_from_module(module)) a = set(get_models_from_module(module))
b = set(x.replace("quantized_", "") for x in models.list_models(module)) b = set(x.replace("quantized_", "") for x in models.list_models(module))
...@@ -121,6 +122,65 @@ def test_list_models(module): ...@@ -121,6 +122,65 @@ def test_list_models(module):
assert a == b assert a == b
@pytest.mark.parametrize(
"include_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
"*not-existing-model-for-test?",
["*resnet*", "*alexnet*"],
["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
("*resnet*", "*alexnet*"),
set(["*resnet*", "*alexnet*"]),
],
)
@pytest.mark.parametrize(
"exclude_filters",
[
None,
[],
(),
"",
"*resnet*",
["*alexnet*"],
["*not-existing-model-for-test?"],
["resnet34", "*not-existing-model-for-test?"],
["resnet34", "*resnet1*"],
("resnet34", "*resnet1*"),
set(["resnet34", "*resnet1*"]),
],
)
def test_list_models_filters(include_filters, exclude_filters):
actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
classification_models = set(get_models_from_module(models))
if isinstance(include_filters, str):
include_filters = [include_filters]
if isinstance(exclude_filters, str):
exclude_filters = [exclude_filters]
if include_filters:
expected = set()
for include_f in include_filters:
include_f = include_f.strip("*?")
expected = expected | set(x for x in classification_models if include_f in x)
else:
expected = classification_models
if exclude_filters:
for exclude_f in exclude_filters:
exclude_f = exclude_f.strip("*?")
if exclude_f != "":
a_exclude = set(x for x in classification_models if exclude_f in x)
expected = expected - a_exclude
assert expected == actual
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name, weight", "name, weight",
[ [
......
import fnmatch
import importlib import importlib
import inspect import inspect
import sys import sys
...@@ -6,7 +7,7 @@ from enum import Enum ...@@ -6,7 +7,7 @@ from enum import Enum
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from types import ModuleType from types import ModuleType
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
from torch import nn from torch import nn
...@@ -203,19 +204,43 @@ def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], C ...@@ -203,19 +204,43 @@ def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], C
return wrapper return wrapper
def list_models(module: Optional[ModuleType] = None) -> List[str]: def list_models(
module: Optional[ModuleType] = None,
include: Union[Iterable[str], str, None] = None,
exclude: Union[Iterable[str], str, None] = None,
) -> List[str]:
""" """
Returns a list with the names of registered models. Returns a list with the names of registered models.
Args: Args:
module (ModuleType, optional): The module from which we want to extract the available models. module (ModuleType, optional): The module from which we want to extract the available models.
include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
wildcards. In case of many filters, the results is the union of individual filters.
exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
wildcards. In case of many filters, the results is removal of all the models that match any individual filter.
Returns: Returns:
models (list): A list with the names of available models. models (list): A list with the names of available models.
""" """
models = [ all_models = {
k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__ k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
] }
if include:
models: Set[str] = set()
if isinstance(include, str):
include = [include]
for include_filter in include:
models = models | set(fnmatch.filter(all_models, include_filter))
else:
models = all_models
if exclude:
if isinstance(exclude, str):
exclude = [exclude]
for exclude_filter in exclude:
models = models - set(fnmatch.filter(all_models, exclude_filter))
return sorted(models) return sorted(models)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment