"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e577bd0f13e1820650810f6864253d70dc76ce08"
Unverified Commit 6f84531e authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

offline mode for firewalled envs (part 2) (#10569)

* more readable test

* add all the missing places

* one more nltk

* better exception check

* revert
parent 54693694
...@@ -35,6 +35,7 @@ from .file_utils import ( ...@@ -35,6 +35,7 @@ from .file_utils import (
cached_path, cached_path,
hf_bucket_url, hf_bucket_url,
is_flax_available, is_flax_available,
is_offline_mode,
is_remote_url, is_remote_url,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
...@@ -342,6 +343,10 @@ class PreTrainedFeatureExtractor: ...@@ -342,6 +343,10 @@ class PreTrainedFeatureExtractor:
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
......
...@@ -1105,6 +1105,10 @@ def cached_path( ...@@ -1105,6 +1105,10 @@ def cached_path(
if isinstance(cache_dir, Path): if isinstance(cache_dir, Path):
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
if is_remote_url(url_or_filename): if is_remote_url(url_or_filename):
# URL, so get it from the cache (downloading if necessary) # URL, so get it from the cache (downloading if necessary)
output_path = get_from_cache( output_path = get_from_cache(
......
...@@ -28,7 +28,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict ...@@ -28,7 +28,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
from .utils import logging from .utils import logging
...@@ -229,6 +229,10 @@ class FlaxPreTrainedModel(ABC): ...@@ -229,6 +229,10 @@ class FlaxPreTrainedModel(ABC):
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
# Load config if we don't provide a configuration # Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path config_path = config if config is not None else pretrained_model_name_or_path
......
...@@ -36,6 +36,7 @@ from .file_utils import ( ...@@ -36,6 +36,7 @@ from .file_utils import (
ModelOutput, ModelOutput,
cached_path, cached_path,
hf_bucket_url, hf_bucket_url,
is_offline_mode,
is_remote_url, is_remote_url,
) )
from .generation_tf_utils import TFGenerationMixin from .generation_tf_utils import TFGenerationMixin
...@@ -1151,6 +1152,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): ...@@ -1151,6 +1152,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None) mirror = kwargs.pop("mirror", None)
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
# Load config if we don't provide a configuration # Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path config_path = config if config is not None else pretrained_model_name_or_path
......
...@@ -27,20 +27,37 @@ class OfflineTests(TestCasePlus): ...@@ -27,20 +27,37 @@ class OfflineTests(TestCasePlus):
# while running an external program # while running an external program
# python one-liner segments # python one-liner segments
load = "from transformers import BertConfig, BertModel, BertTokenizer;"
run = "mname = 'lysandre/tiny-bert-random'; BertConfig.from_pretrained(mname) and BertModel.from_pretrained(mname) and BertTokenizer.from_pretrained(mname);" # this must be loaded before socket.socket is monkey-patched
mock = 'import socket; exec("def offline_socket(*args, **kwargs): raise socket.error(\\"Offline mode is enabled.\\")"); socket.socket = offline_socket;' load = """
from transformers import BertConfig, BertModel, BertTokenizer
"""
run = """
mname = "lysandre/tiny-bert-random"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
print("success")
"""
mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
socket.socket = offline_socket
"""
# baseline - just load from_pretrained with normal network # baseline - just load from_pretrained with normal network
cmd = [sys.executable, "-c", f"{load} {run}"] cmd = [sys.executable, "-c", "\n".join([load, run])]
# should succeed # should succeed
env = self.get_env() env = self.get_env()
result = subprocess.run(cmd, env=env, check=False, capture_output=True) result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr) self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode())
# next emulate no network # next emulate no network
cmd = [sys.executable, "-c", f"{load} {mock} {run}"] cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
# should normally fail as it will fail to lookup the model files w/o the network # should normally fail as it will fail to lookup the model files w/o the network
env["TRANSFORMERS_OFFLINE"] = "0" env["TRANSFORMERS_OFFLINE"] = "0"
...@@ -51,3 +68,4 @@ class OfflineTests(TestCasePlus): ...@@ -51,3 +68,4 @@ class OfflineTests(TestCasePlus):
env["TRANSFORMERS_OFFLINE"] = "1" env["TRANSFORMERS_OFFLINE"] = "1"
result = subprocess.run(cmd, env=env, check=False, capture_output=True) result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr) self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment