Unverified Commit 0a5ef036 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Make `add-new-model-like` work in an env without all frameworks (#16239)

* Make add-new-model-like work without all frameworks installed

* A few fixes

* Last default frameworks
parent f4669364
......@@ -25,6 +25,7 @@ from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
import transformers.models.auto as auto_module
from transformers.models.auto.configuration_auto import model_type_to_module_name
from ..file_utils import is_flax_available, is_tf_available, is_torch_available
from ..utils import logging
from . import BaseTransformersCLICommand
......@@ -501,7 +502,7 @@ def filter_framework_files(
`List[Union[str, os.PathLike]]`: The list of filtered files.
"""
if frameworks is None:
return files
frameworks = get_default_frameworks()
framework_to_file = {}
others = []
......@@ -598,6 +599,20 @@ def find_base_model_checkpoint(
return ""
def get_default_frameworks():
"""
Returns the list of frameworks (PyTorch, TensorFlow, Flax) that are installed in the environment.
"""
frameworks = []
if is_torch_available():
frameworks.append("pt")
if is_tf_available():
frameworks.append("tf")
if is_flax_available():
frameworks.append("flax")
return frameworks
_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")
......@@ -616,17 +631,19 @@ def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = No
that framework as values.
"""
if frameworks is None:
frameworks = ["pt", "tf", "flax"]
frameworks = get_default_frameworks()
modules = {
"pt": auto_module.modeling_auto,
"tf": auto_module.modeling_tf_auto,
"flax": auto_module.modeling_flax_auto,
"pt": auto_module.modeling_auto if is_torch_available() else None,
"tf": auto_module.modeling_tf_auto if is_tf_available() else None,
"flax": auto_module.modeling_flax_auto if is_flax_available() else None,
}
model_classes = {}
for framework in frameworks:
new_model_classes = []
if modules[framework] is None:
raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
for model_mapping_name in model_mappings:
model_mapping = getattr(modules[framework], model_mapping_name)
......@@ -683,9 +700,9 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
available_frameworks.append("pt")
if frameworks is None:
frameworks = available_frameworks.copy()
else:
frameworks = [f for f in frameworks if f in available_frameworks]
frameworks = get_default_frameworks()
frameworks = [f for f in frameworks if f in available_frameworks]
model_classes = retrieve_model_classes(model_type, frameworks=frameworks)
......@@ -738,7 +755,7 @@ def clean_frameworks_in_init(
Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init.
"""
if frameworks is None:
frameworks = ["pt", "tf", "flax"]
frameworks = get_default_frameworks()
names = {"pt": "torch"}
to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks]
......@@ -1040,7 +1057,7 @@ def duplicate_doc_file(
content = f.read()
if frameworks is None:
frameworks = ["pt", "tf", "flax"]
frameworks = get_default_frameworks()
if dest_file is None:
dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.mdx"
......@@ -1302,7 +1319,7 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand):
self.old_model_type = config["old_model_type"]
self.model_patterns = ModelPatterns(**config["new_model_patterns"])
self.add_copied_from = config.get("add_copied_from", True)
self.frameworks = config.get("frameworks", ["pt", "tf", "flax"])
self.frameworks = config.get("frameworks", get_default_frameworks())
self.old_checkpoint = config.get("old_checkpoint", None)
else:
(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment