Unverified commit 28fcf006, authored by Sylvain Gugger, committed by GitHub
Browse files

Remove hack for dynamic modules and use Python functions instead (#22537)

parent 871598be
...@@ -13,14 +13,12 @@ ...@@ -13,14 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Utilities to dynamically load objects from the Hub.""" """Utilities to dynamically load objects from the Hub."""
import filecmp
import importlib import importlib
import os import os
import re import re
import shutil import shutil
import subprocess
import sys import sys
import tempfile
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
...@@ -45,6 +43,7 @@ def init_hf_modules(): ...@@ -45,6 +43,7 @@ def init_hf_modules():
init_path = Path(HF_MODULES_CACHE) / "__init__.py" init_path = Path(HF_MODULES_CACHE) / "__init__.py"
if not init_path.exists(): if not init_path.exists():
init_path.touch() init_path.touch()
importlib.invalidate_caches()
def create_dynamic_module(name: Union[str, os.PathLike]): def create_dynamic_module(name: Union[str, os.PathLike]):
...@@ -60,6 +59,7 @@ def create_dynamic_module(name: Union[str, os.PathLike]): ...@@ -60,6 +59,7 @@ def create_dynamic_module(name: Union[str, os.PathLike]):
init_path = dynamic_module_path / "__init__.py" init_path = dynamic_module_path / "__init__.py"
if not init_path.exists(): if not init_path.exists():
init_path.touch() init_path.touch()
importlib.invalidate_caches()
def get_relative_imports(module_file): def get_relative_imports(module_file):
...@@ -148,34 +148,8 @@ def get_class_in_module(class_name, module_path): ...@@ -148,34 +148,8 @@ def get_class_in_module(class_name, module_path):
""" """
Import a module on the cache directory for modules and extract a class from it. Import a module on the cache directory for modules and extract a class from it.
""" """
with tempfile.TemporaryDirectory() as tmp_dir:
module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path)
module_file_name = module_path.split(os.path.sep)[-1] + ".py"
# Copy to a temporary directory. We need to do this in another process to avoid strange and flaky error
# `ModuleNotFoundError: No module named 'transformers_modules.[module_dir_name].modeling'`
shutil.copy(f"{module_dir}/{module_file_name}", tmp_dir)
# On Windows, we need this character `r` before the path argument of `os.remove`
cmd = f'import os; os.remove(r"{module_dir}{os.path.sep}{module_file_name}")'
# We don't know which python binary file exists in an environment. For example, if `python3` exists but not
# `python`, the call `subprocess.run(["python", ...])` gives `FileNotFoundError` (about python binary). Notice
# that, if the file to be removed is not found, we also have `FileNotFoundError`, but it is not raised to the
# caller's process.
try:
subprocess.run(["python", "-c", cmd])
except FileNotFoundError:
try:
subprocess.run(["python3", "-c", cmd])
except FileNotFoundError:
pass
# copy back the file that we want to import
shutil.copyfile(f"{tmp_dir}/{module_file_name}", f"{module_dir}/{module_file_name}")
# import the module
module_path = module_path.replace(os.path.sep, ".") module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
...@@ -273,13 +247,21 @@ def get_cached_module_file( ...@@ -273,13 +247,21 @@ def get_cached_module_file(
create_dynamic_module(full_submodule) create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]: if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
# We always copy local files (we could hash the file to see if there was a change, and give them the name of # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
# that hash, to only copy when there is a modification but it seems overkill for now). # has changed since last copy.
# The only reason we do the copy is to avoid putting too many folders in sys.path. if not (submodule_path / module_file).exists() or not filecmp.cmp(
resolved_module_file, str(submodule_path / module_file)
):
shutil.copy(resolved_module_file, submodule_path / module_file) shutil.copy(resolved_module_file, submodule_path / module_file)
importlib.invalidate_caches()
for module_needed in modules_needed: for module_needed in modules_needed:
module_needed = f"{module_needed}.py" module_needed = f"{module_needed}.py"
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
if not (submodule_path / module_needed).exists() or not filecmp.cmp(
module_needed_file, str(submodule_path / module_needed)
):
shutil.copy(module_needed_file, submodule_path / module_needed)
importlib.invalidate_caches()
else: else:
# Get the commit hash # Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here. # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
...@@ -293,6 +275,7 @@ def get_cached_module_file( ...@@ -293,6 +275,7 @@ def get_cached_module_file(
if not (submodule_path / module_file).exists(): if not (submodule_path / module_file).exists():
shutil.copy(resolved_module_file, submodule_path / module_file) shutil.copy(resolved_module_file, submodule_path / module_file)
importlib.invalidate_caches()
# Make sure we also have every file with relative # Make sure we also have every file with relative
for module_needed in modules_needed: for module_needed in modules_needed:
if not (submodule_path / module_needed).exists(): if not (submodule_path / module_needed).exists():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment