"examples/nas/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "3ec26b40afd88b8255ebc74b46f475b77aa4d19b"
Commit 7d99e05f authored by thomwolf's avatar thomwolf
Browse files

file_cache has options to extract archives

parent 2c12464a
...@@ -8,13 +8,16 @@ import fnmatch ...@@ -8,13 +8,16 @@ import fnmatch
import json import json
import logging import logging
import os import os
import shutil
import sys import sys
import tarfile
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial, wraps from functools import partial, wraps
from hashlib import sha256 from hashlib import sha256
from typing import Optional from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
import boto3 import boto3
import requests import requests
...@@ -203,7 +206,14 @@ def filename_to_url(filename, cache_dir=None): ...@@ -203,7 +206,14 @@ def filename_to_url(filename, cache_dir=None):
def cached_path( def cached_path(
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None url_or_filename,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
user_agent=None,
extract_compressed_file=False,
force_extract=False,
) -> Optional[str]: ) -> Optional[str]:
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
...@@ -215,6 +225,10 @@ def cached_path( ...@@ -215,6 +225,10 @@ def cached_path(
force_download: if True, re-download the file even if it's already cached in the cache dir. force_download: if True, re-download the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if an incompletely received file is found. resume_download: if True, resume the download if an incompletely received file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests. user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
file in a folder alongside the archive.
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
re-extract the archive and override the folder where it was extracted.
Return: Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
...@@ -229,7 +243,7 @@ def cached_path( ...@@ -229,7 +243,7 @@ def cached_path(
if is_remote_url(url_or_filename): if is_remote_url(url_or_filename):
# URL, so get it from the cache (downloading if necessary) # URL, so get it from the cache (downloading if necessary)
return get_from_cache( output_path = get_from_cache(
url_or_filename, url_or_filename,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
...@@ -239,7 +253,7 @@ def cached_path( ...@@ -239,7 +253,7 @@ def cached_path(
) )
elif os.path.exists(url_or_filename): elif os.path.exists(url_or_filename):
# File, and it exists. # File, and it exists.
return url_or_filename output_path = url_or_filename
elif urlparse(url_or_filename).scheme == "": elif urlparse(url_or_filename).scheme == "":
# File, but it doesn't exist. # File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename)) raise EnvironmentError("file {} not found".format(url_or_filename))
...@@ -247,6 +261,37 @@ def cached_path( ...@@ -247,6 +261,37 @@ def cached_path(
# Something unknown # Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
if extract_compressed_file:
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
return output_path
# Path where we extract compressed archives
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
output_dir, output_file = os.path.split(output_path)
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
return output_path_extracted
# Prevent parallel extractions
lock_path = output_path + ".lock"
with FileLock(lock_path):
shutil.rmtree(output_path_extracted, ignore_errors=True)
os.makedirs(output_path_extracted)
if is_zipfile(output_path):
with ZipFile(output_path, "r") as zip_file:
zip_file.extractall(output_path_extracted)
zip_file.close()
elif tarfile.is_tarfile(output_path):
tar_file = tarfile.open(output_path)
tar_file.extractall(output_path_extracted)
tar_file.close()
return output_path_extracted
return output_path
def split_s3_path(url): def split_s3_path(url):
"""Split a full s3 path into the bucket name and path.""" """Split a full s3 path into the bucket name and path."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment