Unverified Commit a143d947 authored by Bram Vanroy's avatar Bram Vanroy Committed by GitHub
Browse files

Add local_files_only parameter to pretrained items (#2930)

* Add disable_outgoing to pretrained items

Setting disable_outgoing=True disables outgoing traffic:
- etags are not looked up
- models are not downloaded

* parameter name change

* Remove forgotten print
parent 286d1ec7
...@@ -198,6 +198,7 @@ class PretrainedConfig(object): ...@@ -198,6 +198,7 @@ class PretrainedConfig(object):
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
if pretrained_config_archive_map is None: if pretrained_config_archive_map is None:
pretrained_config_archive_map = cls.pretrained_config_archive_map pretrained_config_archive_map = cls.pretrained_config_archive_map
...@@ -219,6 +220,7 @@ class PretrainedConfig(object): ...@@ -219,6 +220,7 @@ class PretrainedConfig(object):
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only,
) )
# Load config dict # Load config dict
if resolved_config_file is None: if resolved_config_file is None:
......
...@@ -214,6 +214,7 @@ def cached_path( ...@@ -214,6 +214,7 @@ def cached_path(
user_agent=None, user_agent=None,
extract_compressed_file=False, extract_compressed_file=False,
force_extract=False, force_extract=False,
local_files_only=False,
) -> Optional[str]: ) -> Optional[str]:
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
...@@ -250,6 +251,7 @@ def cached_path( ...@@ -250,6 +251,7 @@ def cached_path(
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
user_agent=user_agent, user_agent=user_agent,
local_files_only=local_files_only,
) )
elif os.path.exists(url_or_filename): elif os.path.exists(url_or_filename):
# File, and it exists. # File, and it exists.
...@@ -378,7 +380,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): ...@@ -378,7 +380,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
def get_from_cache( def get_from_cache(
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None url,
cache_dir=None,
force_download=False,
proxies=None,
etag_timeout=10,
resume_download=False,
user_agent=None,
local_files_only=False,
) -> Optional[str]: ) -> Optional[str]:
""" """
Given a URL, look for the corresponding file in the local cache. Given a URL, look for the corresponding file in the local cache.
...@@ -395,18 +404,19 @@ def get_from_cache( ...@@ -395,18 +404,19 @@ def get_from_cache(
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
# Get eTag to add to filename, if it exists. etag = None
if url.startswith("s3://"): if not local_files_only:
etag = s3_etag(url, proxies=proxies) # Get eTag to add to filename, if it exists.
else: if url.startswith("s3://"):
try: etag = s3_etag(url, proxies=proxies)
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) else:
if response.status_code != 200: try:
etag = None response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
else: if response.status_code == 200:
etag = response.headers.get("ETag") etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout): except (EnvironmentError, requests.exceptions.Timeout):
etag = None # etag is already None
pass
filename = url_to_filename(url, etag) filename = url_to_filename(url, etag)
...@@ -427,6 +437,15 @@ def get_from_cache( ...@@ -427,6 +437,15 @@ def get_from_cache(
if len(matching_files) > 0: if len(matching_files) > 0:
return os.path.join(cache_dir, matching_files[-1]) return os.path.join(cache_dir, matching_files[-1])
else: else:
# If files cannot be found and local_files_only=True,
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
raise ValueError(
"Cannot find the requested files in the cached path and outgoing traffic has been"
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False."
)
return None return None
# From now on, etag is not None. # From now on, etag is not None.
......
...@@ -376,6 +376,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -376,6 +376,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False) output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
# Load config if we don't provide a configuration # Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
...@@ -388,6 +389,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -388,6 +389,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only,
**kwargs, **kwargs,
) )
else: else:
...@@ -435,6 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -435,6 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only,
) )
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
......
...@@ -395,6 +395,7 @@ class PreTrainedTokenizer(object): ...@@ -395,6 +395,7 @@ class PreTrainedTokenizer(object):
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
s3_models = list(cls.max_model_input_sizes.keys()) s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {} vocab_files = {}
...@@ -462,6 +463,7 @@ class PreTrainedTokenizer(object): ...@@ -462,6 +463,7 @@ class PreTrainedTokenizer(object):
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
resume_download=resume_download, resume_download=resume_download,
local_files_only=local_files_only,
) )
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in s3_models: if pretrained_model_name_or_path in s3_models:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment