"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b242d0f297aa87a0c8d99657a53691ece2dfe492"
Unverified Commit 25ddd91b authored by Nicolas Patry, committed by GitHub

Fixing offline mode for pipeline (when inferring task). (#21113)



* Fixing offline mode for pipeline (when inferring task).

* Update src/transformers/pipelines/__init__.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Updating test to reflect change in exception.

* Fixing offline mode.

* Clean.
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8896ebb9
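For context, here is a minimal sketch (not part of the commit) of the behavior this change targets. It assumes the tiny test checkpoint below was already downloaded to the local cache during an earlier online run: with TRANSFORMERS_OFFLINE=1, passing an explicit task still resolves from local files, while omitting the task now fails fast with a RuntimeError instead of trying to reach the Hub.

import os

# Offline mode is picked up when transformers is imported, so set it first.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from transformers import pipeline

mname = "hf-internal-testing/tiny-random-bert"  # tiny checkpoint used by the tests below

# Explicit task: no Hub lookup is needed, the cached files are used.
pipe = pipeline(task="fill-mask", model=mname)

# No task given: inferring it requires querying the Hub's model info,
# which is impossible offline, so this now raises a RuntimeError.
try:
    pipeline(model=mname)
except RuntimeError as err:
    print(err)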
@@ -40,6 +40,7 @@ from ..tokenization_utils_fast import PreTrainedTokenizerFast
 from ..utils import (
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     is_kenlm_available,
+    is_offline_mode,
     is_pyctcdecode_available,
     is_tf_available,
     is_torch_available,
@@ -398,6 +399,8 @@ def get_supported_tasks() -> List[str]:
 def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
+    if is_offline_mode():
+        raise RuntimeError(f"You cannot infer task automatically within `pipeline` when using offline mode")
     try:
         info = model_info(model, token=use_auth_token)
     except Exception as e:
...
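Before the test changes below, a short aside: a small sketch (not from the commit) of the flag the new guard relies on. is_offline_mode() in transformers.utils reports whether offline mode was requested, typically via the TRANSFORMERS_OFFLINE environment variable, and get_task() now checks it before calling model_info() on the Hub.

import os

# Assumed setup: the variable must be set before transformers is imported.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from transformers.utils import is_offline_mode

if is_offline_mode():
    # Mirrors the new guard in get_task(): bail out instead of querying the Hub.
    print("offline mode enabled, automatic task inference would raise a RuntimeError")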
@@ -15,6 +15,7 @@
 import subprocess
 import sys
 
+from transformers import BertConfig, BertModel, BertTokenizer, pipeline
 from transformers.testing_utils import TestCasePlus, require_torch
@@ -30,7 +31,7 @@ class OfflineTests(TestCasePlus):
         # this must be loaded before socket.socket is monkey-patched
         load = """
-from transformers import BertConfig, BertModel, BertTokenizer
+from transformers import BertConfig, BertModel, BertTokenizer, pipeline
 """
         run = """
@@ -38,34 +39,69 @@ mname = "hf-internal-testing/tiny-random-bert"
 BertConfig.from_pretrained(mname)
 BertModel.from_pretrained(mname)
 BertTokenizer.from_pretrained(mname)
+pipe = pipeline(task="fill-mask", model=mname)
 print("success")
 """
 
         mock = """
 import socket
-def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
+def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet")
 socket.socket = offline_socket
 """
 
+        # Force fetching the files so that we can use the cache
+        mname = "hf-internal-testing/tiny-random-bert"
+        BertConfig.from_pretrained(mname)
+        BertModel.from_pretrained(mname)
+        BertTokenizer.from_pretrained(mname)
+        pipeline(task="fill-mask", model=mname)
+
         # baseline - just load from_pretrained with normal network
-        cmd = [sys.executable, "-c", "\n".join([load, run])]
+        cmd = [sys.executable, "-c", "\n".join([load, run, mock])]
 
         # should succeed
         env = self.get_env()
+        # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
+        env["TRANSFORMERS_OFFLINE"] = "1"
         result = subprocess.run(cmd, env=env, check=False, capture_output=True)
         self.assertEqual(result.returncode, 0, result.stderr)
         self.assertIn("success", result.stdout.decode())
 
-        # next emulate no network
-        cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
+    @require_torch
+    def test_offline_mode_no_internet(self):
+        # python one-liner segments
+        # this must be loaded before socket.socket is monkey-patched
+        load = """
+from transformers import BertConfig, BertModel, BertTokenizer, pipeline
+"""
 
-        # Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
-        # env["TRANSFORMERS_OFFLINE"] = "0"
-        # result = subprocess.run(cmd, env=env, check=False, capture_output=True)
-        # self.assertEqual(result.returncode, 1, result.stderr)
+        run = """
+mname = "hf-internal-testing/tiny-random-bert"
+BertConfig.from_pretrained(mname)
+BertModel.from_pretrained(mname)
+BertTokenizer.from_pretrained(mname)
+pipe = pipeline(task="fill-mask", model=mname)
+print("success")
+"""
 
-        # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
-        env["TRANSFORMERS_OFFLINE"] = "1"
+        mock = """
+import socket
+def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
+socket.socket = offline_socket
+"""
+
+        # Force fetching the files so that we can use the cache
+        mname = "hf-internal-testing/tiny-random-bert"
+        BertConfig.from_pretrained(mname)
+        BertModel.from_pretrained(mname)
+        BertTokenizer.from_pretrained(mname)
+        pipeline(task="fill-mask", model=mname)
+
+        # baseline - just load from_pretrained with normal network
+        cmd = [sys.executable, "-c", "\n".join([load, run, mock])]
+
+        # should succeed
+        env = self.get_env()
         result = subprocess.run(cmd, env=env, check=False, capture_output=True)
         self.assertEqual(result.returncode, 0, result.stderr)
         self.assertIn("success", result.stdout.decode())
@@ -93,7 +129,7 @@ print("success")
         mock = """
 import socket
-def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
+def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
 socket.socket = offline_socket
 """
@@ -119,3 +155,27 @@ socket.socket = offline_socket
         result = subprocess.run(cmd, env=env, check=False, capture_output=True)
         self.assertEqual(result.returncode, 0, result.stderr)
         self.assertIn("success", result.stdout.decode())
+
+    @require_torch
+    def test_offline_mode_pipeline_exception(self):
+        load = """
+from transformers import pipeline
+"""
+        run = """
+mname = "hf-internal-testing/tiny-random-bert"
+pipe = pipeline(model=mname)
+"""
+        mock = """
+import socket
+def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
+socket.socket = offline_socket
+"""
+
+        env = self.get_env()
+        env["TRANSFORMERS_OFFLINE"] = "1"
+        cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
+        result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+        self.assertEqual(result.returncode, 1, result.stderr)
+        self.assertIn(
+            "You cannot infer task automatically within `pipeline` when using offline mode", result.stderr.decode()
+        )
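Outside the test suite, the same no-network guarantee can be reproduced by hand. Below is a rough standalone sketch (not from the commit) combining the two tricks the tests use above, monkey-patching socket.socket and setting TRANSFORMERS_OFFLINE=1; it assumes the checkpoint was cached during an earlier online run.

import os
import socket

# Refuse any network access, as the `mock` snippets above do.
def offline_socket(*args, **kwargs):
    raise RuntimeError("Offline mode is enabled, we shouldn't access internet")

socket.socket = offline_socket

# Tell transformers to rely on local files only; set this before importing it.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from transformers import pipeline

# Succeeds because the task is explicit and the files are already cached.
pipe = pipeline(task="fill-mask", model="hf-internal-testing/tiny-random-bert")
print("success")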