"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "73893fc771396a7645f68d87805b419169e7ee2d"
Unverified Commit 0a5ef036 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Make `add-new-model-like` work in an env without all frameworks (#16239)

* Make add-new-model-like work without all frameworks installed

* A few fixes

* Last default frameworks
parent f4669364
...@@ -25,6 +25,7 @@ from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union ...@@ -25,6 +25,7 @@ from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
import transformers.models.auto as auto_module import transformers.models.auto as auto_module
from transformers.models.auto.configuration_auto import model_type_to_module_name from transformers.models.auto.configuration_auto import model_type_to_module_name
from ..file_utils import is_flax_available, is_tf_available, is_torch_available
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
...@@ -501,7 +502,7 @@ def filter_framework_files( ...@@ -501,7 +502,7 @@ def filter_framework_files(
`List[Union[str, os.PathLike]]`: The list of filtered files. `List[Union[str, os.PathLike]]`: The list of filtered files.
""" """
if frameworks is None: if frameworks is None:
return files frameworks = get_default_frameworks()
framework_to_file = {} framework_to_file = {}
others = [] others = []
...@@ -598,6 +599,20 @@ def find_base_model_checkpoint( ...@@ -598,6 +599,20 @@ def find_base_model_checkpoint(
return "" return ""
def get_default_frameworks():
    """
    Detect which deep-learning frameworks are installed in the current environment.

    Returns:
        `List[str]`: A subset of `["pt", "tf", "flax"]` containing the short name of
        each framework (PyTorch, TensorFlow, Flax) whose availability check passes,
        always in that fixed order.
    """
    # Pair each framework short name with its availability probe, then keep
    # only the ones that are actually importable in this environment.
    availability_checks = (
        ("pt", is_torch_available),
        ("tf", is_tf_available),
        ("flax", is_flax_available),
    )
    return [name for name, is_available in availability_checks if is_available()]
_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES") _re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")
...@@ -616,17 +631,19 @@ def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = No ...@@ -616,17 +631,19 @@ def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = No
that framework as values. that framework as values.
""" """
if frameworks is None: if frameworks is None:
frameworks = ["pt", "tf", "flax"] frameworks = get_default_frameworks()
modules = { modules = {
"pt": auto_module.modeling_auto, "pt": auto_module.modeling_auto if is_torch_available() else None,
"tf": auto_module.modeling_tf_auto, "tf": auto_module.modeling_tf_auto if is_tf_available() else None,
"flax": auto_module.modeling_flax_auto, "flax": auto_module.modeling_flax_auto if is_flax_available() else None,
} }
model_classes = {} model_classes = {}
for framework in frameworks: for framework in frameworks:
new_model_classes = [] new_model_classes = []
if modules[framework] is None:
raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None] model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
for model_mapping_name in model_mappings: for model_mapping_name in model_mappings:
model_mapping = getattr(modules[framework], model_mapping_name) model_mapping = getattr(modules[framework], model_mapping_name)
...@@ -683,9 +700,9 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None): ...@@ -683,9 +700,9 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
available_frameworks.append("pt") available_frameworks.append("pt")
if frameworks is None: if frameworks is None:
frameworks = available_frameworks.copy() frameworks = get_default_frameworks()
else:
frameworks = [f for f in frameworks if f in available_frameworks] frameworks = [f for f in frameworks if f in available_frameworks]
model_classes = retrieve_model_classes(model_type, frameworks=frameworks) model_classes = retrieve_model_classes(model_type, frameworks=frameworks)
...@@ -738,7 +755,7 @@ def clean_frameworks_in_init( ...@@ -738,7 +755,7 @@ def clean_frameworks_in_init(
Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init. Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init.
""" """
if frameworks is None: if frameworks is None:
frameworks = ["pt", "tf", "flax"] frameworks = get_default_frameworks()
names = {"pt": "torch"} names = {"pt": "torch"}
to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks] to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks]
...@@ -1040,7 +1057,7 @@ def duplicate_doc_file( ...@@ -1040,7 +1057,7 @@ def duplicate_doc_file(
content = f.read() content = f.read()
if frameworks is None: if frameworks is None:
frameworks = ["pt", "tf", "flax"] frameworks = get_default_frameworks()
if dest_file is None: if dest_file is None:
dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.mdx" dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.mdx"
...@@ -1302,7 +1319,7 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand): ...@@ -1302,7 +1319,7 @@ class AddNewModelLikeCommand(BaseTransformersCLICommand):
self.old_model_type = config["old_model_type"] self.old_model_type = config["old_model_type"]
self.model_patterns = ModelPatterns(**config["new_model_patterns"]) self.model_patterns = ModelPatterns(**config["new_model_patterns"])
self.add_copied_from = config.get("add_copied_from", True) self.add_copied_from = config.get("add_copied_from", True)
self.frameworks = config.get("frameworks", ["pt", "tf", "flax"]) self.frameworks = config.get("frameworks", get_default_frameworks())
self.old_checkpoint = config.get("old_checkpoint", None) self.old_checkpoint = config.get("old_checkpoint", None)
else: else:
( (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment