Unverified Commit 0270d44f authored by Leandro von Werra's avatar Leandro von Werra Committed by GitHub
Browse files

Context managers (#13900)

* add `ContextManagers` for lists of contexts

* fix import sorting

* add `ContextManagers` tests
parent f875fb0e
...@@ -30,14 +30,14 @@ import tarfile ...@@ -30,14 +30,14 @@ import tarfile
import tempfile import tempfile
import types import types
from collections import OrderedDict, UserDict from collections import OrderedDict, UserDict
from contextlib import contextmanager from contextlib import ExitStack, contextmanager
from dataclasses import fields from dataclasses import fields
from enum import Enum from enum import Enum
from functools import partial, wraps from functools import partial, wraps
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from typing import Any, BinaryIO, ContextManager, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from zipfile import ZipFile, is_zipfile from zipfile import ZipFile, is_zipfile
...@@ -2365,3 +2365,21 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: ...@@ -2365,3 +2365,21 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{username}/{model_id}" return f"{username}/{model_id}"
else: else:
return f"{organization}/{model_id}" return f"{organization}/{model_id}"
class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
in the `fastcore` library.
"""
def __init__(self, context_managers: List[ContextManager]):
self.context_managers = context_managers
self.stack = ExitStack()
def __enter__(self):
for context_manager in self.context_managers:
self.stack.enter_context(context_manager)
def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs)
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import importlib import importlib
import io
import unittest import unittest
import requests import requests
...@@ -20,7 +22,14 @@ import transformers ...@@ -20,7 +22,14 @@ import transformers
# Try to import everything from transformers to ensure every object can be loaded. # Try to import everything from transformers to ensure every object can be loaded.
from transformers import * # noqa F406 from transformers import * # noqa F406
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url from transformers.file_utils import (
CONFIG_NAME,
WEIGHTS_NAME,
ContextManagers,
filename_to_url,
get_from_cache,
hf_bucket_url,
)
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
...@@ -40,6 +49,21 @@ PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d ...@@ -40,6 +49,21 @@ PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes # Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
# Dummy contexts to test `ContextManagers`
@contextlib.contextmanager
def context_en():
print("Welcome!")
yield
print("Bye!")
@contextlib.contextmanager
def context_fr():
print("Bonjour!")
yield
print("Au revoir!")
def test_module_spec(): def test_module_spec():
assert transformers.__spec__ is not None assert transformers.__spec__ is not None
assert importlib.util.find_spec("transformers") is not None assert importlib.util.find_spec("transformers") is not None
...@@ -85,3 +109,26 @@ class GetFromCacheTests(unittest.TestCase): ...@@ -85,3 +109,26 @@ class GetFromCacheTests(unittest.TestCase):
filepath = get_from_cache(url, force_download=True) filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath) metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"')) self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
class ContextManagerTests(unittest.TestCase):
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_no_context(self, mock_stdout):
with ContextManagers([]):
print("Transformers are awesome!")
# The print statement adds a new line at the end of the output
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_one_context(self, mock_stdout):
with ContextManagers([context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_two_context(self, mock_stdout):
with ContextManagers([context_fr(), context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English and French welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment