Commit a4c9338b authored by Aymeric Augustin

Prevent parallel downloads of the same file with a lock.

Since the file is written to the filesystem, a filesystem lock is the
way to go here. Add a dependency on the third-party filelock library to
get cross-platform functionality.
parent b670c266
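For context, a minimal sketch of the locking pattern this commit adopts, assuming the third-party filelock package is installed; the cache path and the fake download below are illustrative placeholders, not code from the commit:

import os

from filelock import FileLock

# Illustrative placeholder; the real cache_path is derived from the URL.
cache_path = "/tmp/example-model.bin"
lock_path = cache_path + ".lock"   # lock file lives next to the cached file

# FileLock creates and locks the .lock file on disk, so acquisition is visible
# to every process on the machine and works on Linux, macOS and Windows.
with FileLock(lock_path):
    if not os.path.exists(cache_path):
        # Only one process runs this block at a time; concurrent callers block
        # on FileLock above and, once it is released, find the file in place.
        with open(cache_path, "wb") as f:
            f.write(b"downloaded bytes would go here")

Because the lock is keyed on the final cache path, two processes fetching the same URL serialize on the same .lock file, while downloads of different files still proceed in parallel.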
@@ -59,6 +59,7 @@ setup(
                                     "tests.*", "tests"]),
     install_requires=['numpy',
                       'boto3',
+                      'filelock',
                       'requests',
                       'tqdm',
                       'regex != 2019.12.17',
@@ -24,6 +24,8 @@ from tqdm.auto import tqdm
 from contextlib import contextmanager
 from . import __version__
 
+from filelock import FileLock
+
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
 try:
@@ -333,53 +335,60 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
     # If we don't have a connection (etag is None) and can't identify the file
     # try to get the last downloaded one
     if not os.path.exists(cache_path) and etag is None:
-        matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
-        matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
+        matching_files = [
+            file
+            for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
+            if not file.endswith('.json') and not file.endswith('.lock')
+        ]
         if matching_files:
             cache_path = os.path.join(cache_dir, matching_files[-1])
 
-    if resume_download:
-        incomplete_path = cache_path + '.incomplete'
-        @contextmanager
-        def _resumable_file_manager():
-            with open(incomplete_path,'a+b') as f:
-                yield f
-        temp_file_manager = _resumable_file_manager
-        if os.path.exists(incomplete_path):
-            resume_size = os.stat(incomplete_path).st_size
-        else:
-            resume_size = 0
-    else:
-        temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
-        resume_size = 0
-
-    if etag is not None and (not os.path.exists(cache_path) or force_download):
-        # Download to temporary file, then copy to cache dir once finished.
-        # Otherwise you get corrupt cache entries if the download gets interrupted.
-        with temp_file_manager() as temp_file:
-            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
-
-            # GET file object
-            if url.startswith("s3://"):
-                if resume_download:
-                    logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
-                s3_get(url, temp_file, proxies=proxies)
-            else:
-                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
-
-        # we are copying the file before closing it, so flush to avoid truncation
-        temp_file.flush()
-
-        logger.info("storing %s in cache at %s", url, cache_path)
-        os.rename(temp_file.name, cache_path)
-
-        logger.info("creating metadata file for %s", cache_path)
-        meta = {'url': url, 'etag': etag}
-        meta_path = cache_path + '.json'
-        with open(meta_path, 'w') as meta_file:
-            output_string = json.dumps(meta)
-            if sys.version_info[0] == 2 and isinstance(output_string, str):
-                output_string = unicode(output_string, 'utf-8')  # The beauty of python 2
-            meta_file.write(output_string)
+    # Prevent parallel downloads of the same file with a lock.
+    lock_path = cache_path + '.lock'
+    with FileLock(lock_path):
+
+        if resume_download:
+            incomplete_path = cache_path + '.incomplete'
+            @contextmanager
+            def _resumable_file_manager():
+                with open(incomplete_path,'a+b') as f:
+                    yield f
+            temp_file_manager = _resumable_file_manager
+            if os.path.exists(incomplete_path):
+                resume_size = os.stat(incomplete_path).st_size
+            else:
+                resume_size = 0
+        else:
+            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
+            resume_size = 0
+
+        if etag is not None and (not os.path.exists(cache_path) or force_download):
+            # Download to temporary file, then copy to cache dir once finished.
+            # Otherwise you get corrupt cache entries if the download gets interrupted.
+            with temp_file_manager() as temp_file:
+                logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
+
+                # GET file object
+                if url.startswith("s3://"):
+                    if resume_download:
+                        logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
+                    s3_get(url, temp_file, proxies=proxies)
+                else:
+                    http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+
+            # we are copying the file before closing it, so flush to avoid truncation
+            temp_file.flush()
+
+            logger.info("storing %s in cache at %s", url, cache_path)
+            os.rename(temp_file.name, cache_path)
+
+            logger.info("creating metadata file for %s", cache_path)
+            meta = {'url': url, 'etag': etag}
+            meta_path = cache_path + '.json'
+            with open(meta_path, 'w') as meta_file:
+                output_string = json.dumps(meta)
+                if sys.version_info[0] == 2 and isinstance(output_string, str):
+                    output_string = unicode(output_string, 'utf-8')  # The beauty of python 2
+                meta_file.write(output_string)
 
     return cache_path