Unverified Commit 0b072304 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Allow relative imports in dynamic code (#15352)

* Allow dynamic modules to use relative imports

* Add tests

* Add one last test

* Changes
parent 628b59e5
...@@ -22,6 +22,8 @@ import sys ...@@ -22,6 +22,8 @@ import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from huggingface_hub import HfFolder, model_info
from ...file_utils import ( from ...file_utils import (
HF_MODULES_CACHE, HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME, TRANSFORMERS_DYNAMIC_MODULE_NAME,
...@@ -79,6 +81,12 @@ def check_imports(filename): ...@@ -79,6 +81,12 @@ def check_imports(filename):
# Only keep the top-level module # Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
# Imports of the form `import .xxx`
relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from .xxx import yyy`
relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
relative_imports = list(set(relative_imports))
# Unique-ify and test we got them all # Unique-ify and test we got them all
imports = list(set(imports)) imports = list(set(imports))
missing_packages = [] missing_packages = []
...@@ -94,6 +102,8 @@ def check_imports(filename): ...@@ -94,6 +102,8 @@ def check_imports(filename):
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
) )
return relative_imports
def get_class_in_module(class_name, module_path): def get_class_in_module(class_name, module_path):
""" """
...@@ -104,10 +114,9 @@ def get_class_in_module(class_name, module_path): ...@@ -104,10 +114,9 @@ def get_class_in_module(class_name, module_path):
return getattr(module, class_name) return getattr(module, class_name)
def get_class_from_dynamic_module( def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str, module_file: str,
class_name: str,
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
resume_download: bool = False, resume_download: bool = False,
...@@ -115,17 +124,10 @@ def get_class_from_dynamic_module( ...@@ -115,17 +124,10 @@ def get_class_from_dynamic_module(
use_auth_token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
**kwargs,
): ):
""" """
Extracts a class from a module file, present in the local folder or repository of a model. Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
Transformers module.
<Tip warning={true}>
Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
therefore only be called on trusted repos.
</Tip>
Args: Args:
pretrained_model_name_or_path (`str` or `os.PathLike`): pretrained_model_name_or_path (`str` or `os.PathLike`):
...@@ -139,8 +141,6 @@ def get_class_from_dynamic_module( ...@@ -139,8 +141,6 @@ def get_class_from_dynamic_module(
module_file (`str`): module_file (`str`):
The name of the module file containing the class to look for. The name of the module file containing the class to look for.
class_name (`str`):
The name of the class to import in the module.
cache_dir (`str` or `os.PathLike`, *optional*): cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used. cache should not be used.
...@@ -169,15 +169,7 @@ def get_class_from_dynamic_module( ...@@ -169,15 +169,7 @@ def get_class_from_dynamic_module(
</Tip> </Tip>
Returns: Returns:
`type`: The class, dynamically imported from the module. `str`: The path to the module inside the cache."""
Examples:
```python
# Download module *modeling.py* from huggingface.co and cache then extract the class *MyBertModel* from this
# module.
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
```"""
if is_offline_mode() and not local_files_only: if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True") logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True local_files_only = True
...@@ -210,7 +202,7 @@ def get_class_from_dynamic_module( ...@@ -210,7 +202,7 @@ def get_class_from_dynamic_module(
raise raise
# Check we have all the requirements in our environment # Check we have all the requirements in our environment
check_imports(resolved_module_file) modules_needed = check_imports(resolved_module_file)
# Now we move the module inside our cached dynamic modules. # Now we move the module inside our cached dynamic modules.
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
...@@ -220,16 +212,131 @@ def get_class_from_dynamic_module( ...@@ -220,16 +212,131 @@ def get_class_from_dynamic_module(
# We always copy local files (we could hash the file to see if there was a change, and give them the name of # We always copy local files (we could hash the file to see if there was a change, and give them the name of
# that hash, to only copy when there is a modification but it seems overkill for now). # that hash, to only copy when there is a modification but it seems overkill for now).
# The only reason we do the copy is to avoid putting too many folders in sys.path. # The only reason we do the copy is to avoid putting too many folders in sys.path.
module_name = module_file
shutil.copy(resolved_module_file, submodule_path / module_file) shutil.copy(resolved_module_file, submodule_path / module_file)
for module_needed in modules_needed:
module_needed = f"{module_needed}.py"
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
else: else:
# The module file will end up being named module_file + the etag. This way we get the benefit of versioning. # Get the commit hash
resolved_module_file_name = Path(resolved_module_file).name # TODO: we will get this info in the etag soon, so retrieve it from there.
module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".") if isinstance(use_auth_token, str):
module_name = "_".join(module_name_parts) + ".py" token = use_auth_token
if not (submodule_path / module_name).exists(): elif use_auth_token is True:
shutil.copy(resolved_module_file, submodule_path / module_name) token = HfFolder.get_token()
else:
token = None
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
# benefit of versioning.
submodule_path = submodule_path / commit_hash
full_submodule = full_submodule + os.path.sep + commit_hash
create_dynamic_module(full_submodule)
if not (submodule_path / module_file).exists():
shutil.copy(resolved_module_file, submodule_path / module_file)
# Make sure we also have every file with relative
for module_needed in modules_needed:
if not (submodule_path / module_needed).exists():
get_cached_module_file(
pretrained_model_name_or_path,
f"{module_needed}.py",
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
)
return os.path.join(full_submodule, module_file)
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
class_name: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.
<Tip warning={true}>
Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
therefore only be called on trusted repos.
</Tip>
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
under a user or organization name, like `dbmdz/bert-base-german-cased`.
- a path to a *directory* containing a configuration file saved using the
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
module_file (`str`):
The name of the module file containing the class to look for.
class_name (`str`):
The name of the class to import in the module.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision(`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
<Tip>
Passing `use_auth_token=True` is required when you want to use a private model.
</Tip>
Returns:
`type`: The class, dynamically imported from the module.
Examples:
```python
# Download module *modeling.py* from huggingface.co and cache then extract the class *MyBertModel* from this
# module.
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
```"""
# And lastly we get the class inside our newly created module # And lastly we get the class inside our newly created module
final_module = os.path.join(full_submodule, module_name.replace(".py", "")) final_module = get_cached_module_file(
return get_class_in_module(class_name, final_module) pretrained_model_name_or_path,
module_file,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
)
return get_class_in_module(class_name, final_module.replace(".py", ""))
...@@ -102,3 +102,7 @@ class AutoConfigTest(unittest.TestCase): ...@@ -102,3 +102,7 @@ class AutoConfigTest(unittest.TestCase):
"hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.", "hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.",
): ):
_ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo") _ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo")
def test_from_pretrained_dynamic_config(self):
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(config.__class__.__name__, "NewModelConfig")
...@@ -324,7 +324,7 @@ class AutoModelTest(unittest.TestCase): ...@@ -324,7 +324,7 @@ class AutoModelTest(unittest.TestCase):
for child, parent in [(a, b) for a in child_model for b in parent_model]: for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}" assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
def test_from_pretrained_dynamic_model(self): def test_from_pretrained_dynamic_model_local(self):
config = BertConfig( config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
...@@ -340,6 +340,14 @@ class AutoModelTest(unittest.TestCase): ...@@ -340,6 +340,14 @@ class AutoModelTest(unittest.TestCase):
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
def test_from_pretrained_dynamic_model_distant(self):
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel")
# This one uses a relative import to a util file, this checks it is downloaded and used properly.
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel")
def test_new_model_registration(self): def test_new_model_registration(self):
AutoConfig.register("new-model", NewModelConfig) AutoConfig.register("new-model", NewModelConfig)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment