Unverified commit 92e97419 authored by Thomas Wolf, committed by GitHub

Merge pull request #2765 from huggingface/extract-cached-archives

Add option to `cached_path` to automatically extract archives
parents 520e7f21 c6c5c3fd
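
In practice, the new flag switches `cached_path` from returning the cached archive itself to returning the folder its contents were extracted into. A minimal usage sketch (the URL is a placeholder, not a real artifact; the folder naming follows the diff's own comment below):

```python
from transformers.file_utils import cached_path

# Hypothetical archive URL, for illustration only.
url = "https://example.com/checkpoints/model.tar.gz"

# Default behaviour: path to the cached archive file.
archive_path = cached_path(url)

# New behaviour: the archive is unpacked into a sibling folder named after
# the cached file (dots replaced by dashes, "-extracted" appended), and
# that folder's path is returned. force_extract=True discards any previous
# extraction and redoes it.
extracted_dir = cached_path(url, extract_compressed_file=True)

# Works for local archives too: "./model.zip" -> "./model-zip-extracted/"
extracted_dir = cached_path("./model.zip", extract_compressed_file=True)
```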
@@ -8,13 +8,16 @@
 import fnmatch
 import json
 import logging
 import os
+import shutil
 import sys
+import tarfile
 import tempfile
 from contextlib import contextmanager
 from functools import partial, wraps
 from hashlib import sha256
 from typing import Optional
 from urllib.parse import urlparse
+from zipfile import ZipFile, is_zipfile
 import boto3
 import requests
@@ -203,7 +206,14 @@ def filename_to_url(filename, cache_dir=None):
 def cached_path(
-    url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
+    url_or_filename,
+    cache_dir=None,
+    force_download=False,
+    proxies=None,
+    resume_download=False,
+    user_agent=None,
+    extract_compressed_file=False,
+    force_extract=False,
 ) -> Optional[str]:
     """
     Given something that might be a URL (or might be a local path),
@@ -215,6 +225,10 @@ def cached_path(
         force_download: if True, re-download the file even if it's already cached in the cache dir.
         resume_download: if True, resume the download if an incompletely received file is found.
         user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
+        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
+            file in a folder alongside the archive.
+        force_extract: if True, when extract_compressed_file is True and the archive was already extracted,
+            re-extract the archive and override the folder where it was extracted.
 
     Return:
         None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
@@ -229,7 +243,7 @@
     if is_remote_url(url_or_filename):
         # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(
+        output_path = get_from_cache(
             url_or_filename,
             cache_dir=cache_dir,
             force_download=force_download,
@@ -239,7 +253,7 @@
         )
     elif os.path.exists(url_or_filename):
         # File, and it exists.
-        return url_or_filename
+        output_path = url_or_filename
     elif urlparse(url_or_filename).scheme == "":
         # File, but it doesn't exist.
         raise EnvironmentError("file {} not found".format(url_or_filename))
@@ -247,6 +261,39 @@
         # Something unknown
         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
 
+    if extract_compressed_file:
+        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
+            return output_path
+
+        # Path where we extract compressed archives
+        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
+        output_dir, output_file = os.path.split(output_path)
+        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
+        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
+
+        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
+            return output_path_extracted
+
+        # Prevent parallel extractions
+        lock_path = output_path + ".lock"
+        with FileLock(lock_path):
+            shutil.rmtree(output_path_extracted, ignore_errors=True)
+            os.makedirs(output_path_extracted)
+            if is_zipfile(output_path):
+                with ZipFile(output_path, "r") as zip_file:
+                    zip_file.extractall(output_path_extracted)
+            elif tarfile.is_tarfile(output_path):
+                with tarfile.open(output_path) as tar_file:
+                    tar_file.extractall(output_path_extracted)
+            else:
+                raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
+
+        return output_path_extracted
+
+    return output_path
+
 
 def split_s3_path(url):
     """Split a full s3 path into the bucket name and path."""
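The directory-naming comment above compresses the rule into one example; here is the same computation as a standalone sketch (the helper name `_extract_dir_for` is hypothetical, the two lines of logic are lifted from the diff):

```python
import os

def _extract_dir_for(archive_path):
    # Dots in the file name become dashes and "-extracted" is appended,
    # so the folder sits next to the archive and cannot collide with it.
    output_dir, output_file = os.path.split(archive_path)
    return os.path.join(output_dir, output_file.replace(".", "-") + "-extracted")

print(_extract_dir_for("./model.zip"))    # ./model-zip-extracted
print(_extract_dir_for("./data.tar.gz"))  # ./data-tar-gz-extracted
```

The `FileLock` guard exists because several processes (e.g. distributed training workers) may call `cached_path` on the same archive at once: the lock serializes concurrent extractions so two workers never write into the folder at the same time, while callers arriving after a completed extraction hit the early return and reuse the populated folder.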
@@ -303,10 +303,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                 archive_file = pretrained_model_name_or_path + ".index"
             else:
-                archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
-                if from_pt:
-                    raise EnvironmentError(
-                        "Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name."
-                    )
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME)
+                )
 
         # redirect to the cache, if necessary
@@ -421,10 +421,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 )
                 archive_file = pretrained_model_name_or_path + ".index"
             else:
-                archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
-                if from_tf:
-                    raise EnvironmentError(
-                        "Loading a PyTorch model from a TF checkpoint is not supported when using a model identifier name."
-                    )
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path, postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME)
+                )
 
         # redirect to the cache, if necessary
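
The last two hunks are symmetric: instead of raising an error, a bare model identifier now resolves to the other framework's checkpoint on the bucket whenever `from_pt` / `from_tf` is set. A sketch of the selection logic (the two helper functions are illustrative, not part of the diff; the constant values are assumptions matching the library's conventions at the time):

```python
# Assumed constant values, for illustration.
WEIGHTS_NAME = "pytorch_model.bin"  # PyTorch checkpoint file name
TF2_WEIGHTS_NAME = "tf_model.h5"    # TensorFlow 2 checkpoint file name

def tf_postfix(from_pt: bool) -> str:
    # TFPreTrainedModel: fetch PyTorch weights when converting, else TF2 weights.
    return WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME

def pt_postfix(from_tf: bool) -> str:
    # PreTrainedModel: fetch TF2 weights when converting, else PyTorch weights.
    return TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME
```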