Commit 0e4cc050 authored by Sergey Mironov's avatar Sergey Mironov
Browse files

Add support for resumable downloads for HTTP protocol.

parent 0e64fec1
...@@ -92,6 +92,9 @@ class AutoConfig(object): ...@@ -92,6 +92,9 @@ class AutoConfig(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
......
...@@ -93,6 +93,9 @@ class PretrainedConfig(object): ...@@ -93,6 +93,9 @@ class PretrainedConfig(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -119,6 +122,7 @@ class PretrainedConfig(object): ...@@ -119,6 +122,7 @@ class PretrainedConfig(object):
""" """
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop('proxies', None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
...@@ -130,7 +134,8 @@ class PretrainedConfig(object): ...@@ -130,7 +134,8 @@ class PretrainedConfig(object):
config_file = pretrained_model_name_or_path config_file = pretrained_model_name_or_path
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format( msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
......
...@@ -22,6 +22,7 @@ from botocore.config import Config ...@@ -22,6 +22,7 @@ from botocore.config import Config
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
import requests import requests
from tqdm import tqdm from tqdm import tqdm
from contextlib import contextmanager
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
...@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None): ...@@ -152,7 +153,7 @@ def filename_to_url(filename, cache_dir=None):
return url, etag return url, etag
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None): def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False):
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and determine which. If it's a URL, download the file and cache it, and
...@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N ...@@ -161,6 +162,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
Args: Args:
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-dowload the file even if it's already cached in the cache dir. force_download: if True, re-dowload the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if incompletly recieved file is found.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE cache_dir = TRANSFORMERS_CACHE
...@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N ...@@ -173,7 +175,9 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
if parsed.scheme in ('http', 'https', 's3'): if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary) # URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies) return get_from_cache(url_or_filename, cache_dir=cache_dir,
force_download=force_download, proxies=proxies,
resume_download=resume_download)
elif os.path.exists(url_or_filename): elif os.path.exists(url_or_filename):
# File, and it exists. # File, and it exists.
return url_or_filename return url_or_filename
...@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None): ...@@ -234,19 +238,22 @@ def s3_get(url, temp_file, proxies=None):
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file, proxies=None): def http_get(url, temp_file, proxies=None, resume_size=0):
req = requests.get(url, stream=True, proxies=proxies) headers={'Range':'bytes=%d-'%(resume_size,)} if resume_size > 0 else None
content_length = req.headers.get('Content-Length') response = requests.get(url, stream=True, proxies=proxies, headers=headers)
total = int(content_length) if content_length is not None else None if response.status_code == 416: # Range not satisfiable
progress = tqdm(unit="B", total=total) return
for chunk in req.iter_content(chunk_size=1024): content_length = response.headers.get('Content-Length')
total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, initial=resume_size)
for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
progress.update(len(chunk)) progress.update(len(chunk))
temp_file.write(chunk) temp_file.write(chunk)
progress.close() progress.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10): def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False):
""" """
Given a URL, look for the corresponding dataset in the local cache. Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. If it's not there, download it. Then return the path to the cached file.
...@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -289,17 +296,35 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
if matching_files: if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1]) cache_path = os.path.join(cache_dir, matching_files[-1])
if resume_download:
incomplete_path = cache_path + '.incomplete'
@contextmanager
def _resumable_file_manager():
with open(incomplete_path,'a+b') as f:
yield f
os.remove(incomplete_path)
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
resume_size = 0
else:
temp_file_manager = tempfile.NamedTemporaryFile
resume_size = 0
if not os.path.exists(cache_path) or force_download: if not os.path.exists(cache_path) or force_download:
# Download to temporary file, then copy to cache dir once finished. # Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted. # Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file: with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
# GET file object # GET file object
if url.startswith("s3://"): if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies) s3_get(url, temp_file, proxies=proxies)
else: else:
http_get(url, temp_file, proxies=proxies) http_get(url, temp_file, proxies=proxies, resume_size=resume_size)
# we are copying the file before closing it, so flush to avoid truncation # we are copying the file before closing it, so flush to avoid truncation
temp_file.flush() temp_file.flush()
......
...@@ -112,6 +112,9 @@ class AutoModel(object): ...@@ -112,6 +112,9 @@ class AutoModel(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -237,6 +240,8 @@ class AutoModelWithLMHead(object): ...@@ -237,6 +240,8 @@ class AutoModelWithLMHead(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
...@@ -357,6 +362,9 @@ class AutoModelForSequenceClassification(object): ...@@ -357,6 +362,9 @@ class AutoModelForSequenceClassification(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
......
...@@ -109,6 +109,9 @@ class TFAutoModel(object): ...@@ -109,6 +109,9 @@ class TFAutoModel(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object): ...@@ -237,6 +240,9 @@ class TFAutoModelWithLMHead(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object): ...@@ -360,6 +366,9 @@ class TFAutoModelForSequenceClassification(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object): ...@@ -472,6 +481,9 @@ class TFAutoModelForQuestionAnswering(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
......
...@@ -176,6 +176,9 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -176,6 +176,9 @@ class TFPreTrainedModel(tf.keras.Model):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -201,6 +204,7 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -201,6 +204,7 @@ class TFPreTrainedModel(tf.keras.Model):
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
from_pt = kwargs.pop('from_pt', False) from_pt = kwargs.pop('from_pt', False)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop('proxies', None)
# Load config # Load config
...@@ -209,6 +213,7 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -209,6 +213,7 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path, *model_args, pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True, cache_dir=cache_dir, return_unused_kwargs=True,
force_download=force_download, force_download=force_download,
resume_download=resume_download,
**kwargs **kwargs
) )
else: else:
...@@ -236,7 +241,8 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -236,7 +241,8 @@ class TFPreTrainedModel(tf.keras.Model):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
resume_download=resume_download, proxies=proxies)
except EnvironmentError as e: except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error( logger.error(
......
...@@ -246,6 +246,9 @@ class PreTrainedModel(nn.Module): ...@@ -246,6 +246,9 @@ class PreTrainedModel(nn.Module):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists. Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -275,6 +278,7 @@ class PreTrainedModel(nn.Module): ...@@ -275,6 +278,7 @@ class PreTrainedModel(nn.Module):
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', False) from_tf = kwargs.pop('from_tf', False)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop('proxies', None)
output_loading_info = kwargs.pop('output_loading_info', False) output_loading_info = kwargs.pop('output_loading_info', False)
...@@ -284,6 +288,7 @@ class PreTrainedModel(nn.Module): ...@@ -284,6 +288,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path, *model_args, pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True, cache_dir=cache_dir, return_unused_kwargs=True,
force_download=force_download, force_download=force_download,
resume_download=resume_download,
**kwargs **kwargs
) )
else: else:
...@@ -315,7 +320,8 @@ class PreTrainedModel(nn.Module): ...@@ -315,7 +320,8 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained weights.".format( msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
......
...@@ -87,6 +87,9 @@ class AutoTokenizer(object): ...@@ -87,6 +87,9 @@ class AutoTokenizer(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the vocabulary files and override the cached versions if they exists. Force to (re-)download the vocabulary files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
......
...@@ -251,6 +251,9 @@ class PreTrainedTokenizer(object): ...@@ -251,6 +251,9 @@ class PreTrainedTokenizer(object):
force_download: (`optional`) boolean, default False: force_download: (`optional`) boolean, default False:
Force to (re-)download the vocabulary files and override the cached versions if they exists. Force to (re-)download the vocabulary files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None: proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request. The proxies are used on each request.
...@@ -286,6 +289,7 @@ class PreTrainedTokenizer(object): ...@@ -286,6 +289,7 @@ class PreTrainedTokenizer(object):
def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop('proxies', None)
s3_models = list(cls.max_model_input_sizes.keys()) s3_models = list(cls.max_model_input_sizes.keys())
...@@ -352,7 +356,7 @@ class PreTrainedTokenizer(object): ...@@ -352,7 +356,7 @@ class PreTrainedTokenizer(object):
if file_path is None: if file_path is None:
resolved_vocab_files[file_id] = None resolved_vocab_files[file_id] = None
else: else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in s3_models: if pretrained_model_name_or_path in s3_models:
msg = "Couldn't reach server at '{}' to download vocabulary files." msg = "Couldn't reach server at '{}' to download vocabulary files."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment