# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

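# These tests check that offline mode (TRANSFORMERS_OFFLINE=1 / HF_HUB_OFFLINE=1) loads models from the
# local cache without any network access. Since the env vars are only read when `transformers` is imported,
# each test assembles a small script and runs it in a subprocess with the desired environment
# (see `_execute_with_env` below).
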
import subprocess
import sys
from typing import Tuple

from transformers import BertConfig, BertModel, BertTokenizer, pipeline
from transformers.testing_utils import TestCasePlus, require_torch


class OfflineTests(TestCasePlus):
    @require_torch
    def test_offline_mode(self):
        # this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
        # `transformers` is loaded, and it's too late to do that from inside pytest - so we
        # change it while running an external program

        # python one-liner segments

        # this must be loaded before socket.socket is monkey-patched
        load = """
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
        """

        run = """
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success")
        """

        mock = """
import socket
def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access the internet")
socket.socket = offline_socket
        """

        # Force fetching the files so that we can use the cache
        mname = "hf-internal-testing/tiny-random-bert"
        BertConfig.from_pretrained(mname)
        BertModel.from_pretrained(mname)
        BertTokenizer.from_pretrained(mname)
        pipeline(task="fill-mask", model=mname)

        # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use the files cached above,
        # even though the mocked socket would raise on any network access
        stdout, _ = self._execute_with_env(load, run, mock, TRANSFORMERS_OFFLINE="1")
        self.assertIn("success", stdout)

    @require_torch
    def test_offline_mode_no_internet(self):
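        # same as `test_offline_mode`, but without TRANSFORMERS_OFFLINE: the socket is mocked to raise a
        # socket.error, and loading should still fall back to the locally cached files
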
        # python one-liner segments
        # this must be loaded before socket.socket is monkey-patched
        load = """
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
        """

        run = """
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success")
        """

        mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
socket.socket = offline_socket
        """

        # Force fetching the files so that we can use the cache
        mname = "hf-internal-testing/tiny-random-bert"
        BertConfig.from_pretrained(mname)
        BertModel.from_pretrained(mname)
        BertTokenizer.from_pretrained(mname)
        pipeline(task="fill-mask", model=mname)

        # even without TRANSFORMERS_OFFLINE, loading should succeed despite the mocked socket raising,
        # since the files are already in the local cache
        stdout, _ = self._execute_with_env(load, run, mock)
        self.assertIn("success", stdout)

    @require_torch
    def test_offline_mode_sharded_checkpoint(self):
        # this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
        # `transformers` is loaded, and it's too late to do that from inside pytest - so we
        # change it while running an external program
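        # unlike `test_offline_mode`, the checkpoint here is split into multiple shards, so offline loading
        # must resolve the shard index and every shard file from the local cache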

        # python one-liner segments

        # this must be loaded before socket.socket is monkey-patched
        load = """
from transformers import BertConfig, BertModel, BertTokenizer
        """

        run = """
mname = "hf-internal-testing/tiny-random-bert-sharded"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
print("success")
        """

        mock = """
import socket
def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
socket.socket = offline_socket
        """

        # baseline - just load from_pretrained with normal network
        # should succeed
        stdout, _ = self._execute_with_env(load, run)
        self.assertIn("success", stdout)

        # next emulate no network
        # Doesn't fail anymore since the model is in the cache due to other tests, so this check is commented out.
        # self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="0")

        # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
        stdout, _ = self._execute_with_env(load, mock, run, TRANSFORMERS_OFFLINE="1")
        self.assertIn("success", stdout)

    @require_torch
    def test_offline_mode_pipeline_exception(self):
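        # `pipeline(model=...)` without an explicit task has to query the Hub to infer the task, which is
        # impossible in offline mode, so it should fail with a clear error message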
        load = """
from transformers import pipeline
        """
        run = """
mname = "hf-internal-testing/tiny-random-bert"
pipe = pipeline(model=mname)
        """

        mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
socket.socket = offline_socket
        """

        _, stderr = self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="1")
        self.assertIn(
            "You cannot infer task automatically within `pipeline` when using offline mode",
            stderr.replace("\n", ""),
        )

    @require_torch
    def test_offline_model_dynamic_model(self):
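        # models with custom code on the Hub (trust_remote_code=True) should also work offline once the
        # remote code and weights have been cached by the baseline online run below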
        load = """
from transformers import AutoModel
        """
        run = """
mname = "hf-internal-testing/test_dynamic_model"
AutoModel.from_pretrained(mname, trust_remote_code=True)
print("success")
        """

        # baseline - just load from_pretrained with normal network
        # should succeed
        stdout, _ = self._execute_with_env(load, run)
        self.assertIn("success", stdout)

        # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
        stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1")
        self.assertIn("success", stdout)

    def test_is_offline_mode(self):
        """
        Test the `is_offline_mode` helper (should respect both HF_HUB_OFFLINE and legacy TRANSFORMERS_OFFLINE env vars)
        """
        load = "from transformers.utils import is_offline_mode"
        run = "print(is_offline_mode())"

        stdout, _ = self._execute_with_env(load, run)
        self.assertIn("False", stdout)

        stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1")
        self.assertIn("True", stdout)

        stdout, _ = self._execute_with_env(load, run, HF_HUB_OFFLINE="1")
        self.assertIn("True", stdout)

    def _execute_with_env(self, *commands: str, should_fail: bool = False, **env) -> Tuple[str, str]:
        """Execute Python code with a given environment and return the stdout/stderr as strings.

        If `should_fail=True`, the command is expected to fail. Otherwise, it should succeed.
        Environment variables can be passed as keyword arguments.
        """
        # Build command
        cmd = [sys.executable, "-c", "\n".join(commands)]

        # Configure env
        new_env = self.get_env()
        new_env.update(env)

        # Run command
        result = subprocess.run(cmd, env=new_env, check=False, capture_output=True)

        # Check execution
        if should_fail:
            self.assertNotEqual(result.returncode, 0, result.stderr)
        else:
            self.assertEqual(result.returncode, 0, result.stderr)

        # Return output
        return result.stdout.decode(), result.stderr.decode()