Unverified commit 28fcf006, authored by Sylvain Gugger and committed by GitHub
Browse files

Remove hack for dynamic modules and use Python functions instead (#22537)

parent 871598be
......@@ -13,14 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to dynamically load objects from the Hub."""
import filecmp
import importlib
import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Dict, Optional, Union
......@@ -45,6 +43,7 @@ def init_hf_modules():
init_path = Path(HF_MODULES_CACHE) / "__init__.py"
if not init_path.exists():
init_path.touch()
importlib.invalidate_caches()
def create_dynamic_module(name: Union[str, os.PathLike]):
......@@ -60,6 +59,7 @@ def create_dynamic_module(name: Union[str, os.PathLike]):
init_path = dynamic_module_path / "__init__.py"
if not init_path.exists():
init_path.touch()
importlib.invalidate_caches()
def get_relative_imports(module_file):
......@@ -148,35 +148,9 @@ def get_class_in_module(class_name, module_path):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path)
module_file_name = module_path.split(os.path.sep)[-1] + ".py"
# Copy to a temporary directory. We need to do this in another process to avoid strange and flaky error
# `ModuleNotFoundError: No module named 'transformers_modules.[module_dir_name].modeling'`
shutil.copy(f"{module_dir}/{module_file_name}", tmp_dir)
# On Windows, we need this character `r` before the path argument of `os.remove`
cmd = f'import os; os.remove(r"{module_dir}{os.path.sep}{module_file_name}")'
# We don't know which python binary file exists in an environment. For example, if `python3` exists but not
# `python`, the call `subprocess.run(["python", ...])` gives `FileNotFoundError` (about python binary). Notice
# that, if the file to be removed is not found, we also have `FileNotFoundError`, but it is not raised to the
# caller's process.
try:
subprocess.run(["python", "-c", cmd])
except FileNotFoundError:
try:
subprocess.run(["python3", "-c", cmd])
except FileNotFoundError:
pass
# copy back the file that we want to import
shutil.copyfile(f"{tmp_dir}/{module_file_name}", f"{module_dir}/{module_file_name}")
# import the module
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
return getattr(module, class_name)
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
return getattr(module, class_name)
def get_cached_module_file(
......@@ -273,13 +247,21 @@ def get_cached_module_file(
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
# that hash, to only copy when there is a modification but it seems overkill for now).
# The only reason we do the copy is to avoid putting too many folders in sys.path.
shutil.copy(resolved_module_file, submodule_path / module_file)
# We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
# has changed since last copy.
if not (submodule_path / module_file).exists() or not filecmp.cmp(
resolved_module_file, str(submodule_path / module_file)
):
shutil.copy(resolved_module_file, submodule_path / module_file)
importlib.invalidate_caches()
for module_needed in modules_needed:
module_needed = f"{module_needed}.py"
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
if not (submodule_path / module_needed).exists() or not filecmp.cmp(
module_needed_file, str(submodule_path / module_needed)
):
shutil.copy(module_needed_file, submodule_path / module_needed)
importlib.invalidate_caches()
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
......@@ -293,6 +275,7 @@ def get_cached_module_file(
if not (submodule_path / module_file).exists():
shutil.copy(resolved_module_file, submodule_path / module_file)
importlib.invalidate_caches()
# Make sure we also have every file needed by relative imports
for module_needed in modules_needed:
if not (submodule_path / module_needed).exists():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment