Commit 09e1b0b4 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

remove transformers dependency

parent 5a784f98
...@@ -24,18 +24,19 @@ import re ...@@ -24,18 +24,19 @@ import re
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
from requests import HTTPError from requests import HTTPError
from transformers.utils import ( from huggingface_hub import hf_hub_download
from .utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DIFFUSERS_CACHE,
EntryNotFoundError, EntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
cached_path,
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging, logging,
) )
from . import __version__ from . import __version__
...@@ -90,12 +91,11 @@ class ConfigMixin: ...@@ -90,12 +91,11 @@ class ConfigMixin:
self.to_json_file(output_config_file) self.to_json_file(output_config_file)
logger.info(f"ConfigMixinuration saved in {output_config_file}") logger.info(f"ConfigMixinuration saved in {output_config_file}")
@classmethod @classmethod
def get_config_dict( def get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
...@@ -105,27 +105,24 @@ class ConfigMixin: ...@@ -105,27 +105,24 @@ class ConfigMixin:
user_agent = {"file_type": "config"} user_agent = {"file_type": "config"}
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
configuration_file = cls.config_name
if os.path.isdir(pretrained_model_name_or_path): if os.path.isfile(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, configuration_file) config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
# Load from a PyTorch checkpoint
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
else: else:
config_file = hf_bucket_url( raise EnvironmentError(
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
) )
else:
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_config_file = cached_path( config_file = hf_hub_download(
config_file, pretrained_model_name_or_path,
filename=cls.config_name,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
...@@ -150,7 +147,7 @@ class ConfigMixin: ...@@ -150,7 +147,7 @@ class ConfigMixin:
) )
except EntryNotFoundError: except EntryNotFoundError:
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}." f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
) )
except HTTPError as err: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
...@@ -160,7 +157,7 @@ class ConfigMixin: ...@@ -160,7 +157,7 @@ class ConfigMixin:
raise EnvironmentError( raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in" f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory" f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the" f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
) )
except EnvironmentError: except EnvironmentError:
...@@ -168,22 +165,17 @@ class ConfigMixin: ...@@ -168,22 +165,17 @@ class ConfigMixin:
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. " "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {configuration_file} file" f"containing a {cls.config_name} file"
) )
try: try:
# Load config dict # Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file) config_dict = cls._dict_from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError): except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError( raise EnvironmentError(
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file." f"It looks like the config file at '{config_file}' is not a valid JSON file."
) )
if resolved_config_file == config_file:
logger.info(f"loading configuration file {config_file}")
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
return config_dict return config_dict
@classmethod @classmethod
...@@ -199,9 +191,7 @@ class ConfigMixin: ...@@ -199,9 +191,7 @@ class ConfigMixin:
# use value from config dict # use value from config dict
init_dict[key] = config_dict.pop(key) init_dict[key] = config_dict.pop(key)
unused_kwargs = config_dict.update(kwargs) unused_kwargs = config_dict.update(kwargs)
passed_keys = set(init_dict.keys()) passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0: if len(expected_keys - passed_keys) > 0:
logger.warn( logger.warn(
......
...@@ -22,16 +22,8 @@ import sys ...@@ -22,16 +22,8 @@ import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from huggingface_hub import HfFolder, model_info from huggingface_hub import cached_download
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
from transformers.utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_path,
hf_bucket_url,
is_offline_mode,
logging,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -219,7 +211,7 @@ def get_cached_module_file( ...@@ -219,7 +211,7 @@ def get_cached_module_file(
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_module_file = cached_path( resolved_module_file = cached_download(
module_file_or_url, module_file_or_url,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
...@@ -237,7 +229,7 @@ def get_cached_module_file( ...@@ -237,7 +229,7 @@ def get_cached_module_file(
modules_needed = check_imports(resolved_module_file) modules_needed = check_imports(resolved_module_file)
# Now we move the module inside our cached dynamic modules. # Now we move the module inside our cached dynamic modules.
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule) create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule submodule_path = Path(HF_MODULES_CACHE) / full_submodule
# We always copy local files (we could hash the file to see if there was a change, and give them the name of # We always copy local files (we could hash the file to see if there was a change, and give them the name of
......
...@@ -21,18 +21,15 @@ import torch ...@@ -21,18 +21,15 @@ import torch
from torch import Tensor, device from torch import Tensor, device
from requests import HTTPError from requests import HTTPError
from huggingface_hub import hf_hub_download
# CHANGE to diffusers.utils from .utils import (
from transformers.utils import (
CONFIG_NAME, CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError, EntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
cached_path,
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging, logging,
) )
...@@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module): ...@@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
</Tip> </Tip>
""" """
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
...@@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module): ...@@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
# Load config if we don't provide a configuration # Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path config_path = pretrained_model_name_or_path
model, unused_kwargs = cls.from_config( model, unused_kwargs = cls.from_config(
...@@ -353,24 +345,17 @@ class ModelMixin(torch.nn.Module): ...@@ -353,24 +345,17 @@ class ModelMixin(torch.nn.Module):
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else: else:
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
) )
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else: else:
filename = WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_archive_file = cached_path( model_file = hf_hub_download(
archive_file, pretrained_model_name_or_path,
filename=WEIGHTS_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
...@@ -394,7 +379,7 @@ class ModelMixin(torch.nn.Module): ...@@ -394,7 +379,7 @@ class ModelMixin(torch.nn.Module):
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
) )
except EntryNotFoundError: except EntryNotFoundError:
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {filename}.") raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
except HTTPError as err: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
...@@ -415,17 +400,12 @@ class ModelMixin(torch.nn.Module): ...@@ -415,17 +400,12 @@ class ModelMixin(torch.nn.Module):
f"containing a file named {WEIGHTS_NAME}" f"containing a file named {WEIGHTS_NAME}"
) )
if resolved_archive_file == archive_file:
logger.info(f"loading weights file {archive_file}")
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
# restore default dtype # restore default dtype
state_dict = load_state_dict(resolved_archive_file) state_dict = load_state_dict(model_file)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model, model,
state_dict, state_dict,
resolved_archive_file, model_file,
pretrained_model_name_or_path, pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes, ignore_mismatched_sizes=ignore_mismatched_sizes,
) )
......
...@@ -19,8 +19,7 @@ import os ...@@ -19,8 +19,7 @@ import os
from typing import Optional, Union from typing import Optional, Union
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
# CHANGE to diffusers.utils from .utils import logging, DIFFUSERS_CACHE
from transformers.utils import logging
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
...@@ -56,7 +55,6 @@ class DiffusionPipeline(ConfigMixin): ...@@ -56,7 +55,6 @@ class DiffusionPipeline(ConfigMixin):
register_dict = {name: (library, class_name)} register_dict = {name: (library, class_name)}
# save model index config # save model index config
self.register(**register_dict) self.register(**register_dict)
...@@ -94,9 +92,29 @@ class DiffusionPipeline(ConfigMixin): ...@@ -94,9 +92,29 @@ class DiffusionPipeline(ConfigMixin):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Add docstrings
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path): if not os.path.isdir(pretrained_model_name_or_path):
cached_folder = snapshot_download(pretrained_model_name_or_path) cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
output_loading_info=output_loading_info,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
else: else:
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
...@@ -110,7 +128,6 @@ class DiffusionPipeline(ConfigMixin): ...@@ -110,7 +128,6 @@ class DiffusionPipeline(ConfigMixin):
else: else:
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {} init_kwargs = {}
......
#!/usr/bin/env python
# coding=utf-8
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from requests.exceptions import HTTPError
hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "diffusers")
CONFIG_NAME = "config.json"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
not have access to.
"""
class EntryNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities."""
import logging
import os
import sys
import threading
from logging import CRITICAL # NOQA
from logging import DEBUG # NOQA
from logging import ERROR # NOQA
from logging import FATAL # NOQA
from logging import INFO # NOQA
from logging import NOTSET # NOQA
from logging import WARN # NOQA
from logging import WARNING # NOQA
from typing import Optional
from tqdm import auto as tqdm_lib
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
_default_log_level = logging.WARNING
_tqdm_active = True
def _get_default_logging_level():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level`
"""
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
)
return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> logging.Logger:
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
global _default_handler
with _lock:
if _default_handler:
# This library has already configured the library root logger.
return
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
_default_handler.flush = sys.stderr.flush
# Apply our default configuration to the library root logger.
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def _reset_library_root_logger() -> None:
global _default_handler
with _lock:
if not _default_handler:
return
library_root_logger = _get_library_root_logger()
library_root_logger.removeHandler(_default_handler)
library_root_logger.setLevel(logging.NOTSET)
_default_handler = None
def get_log_levels_dict():
return log_levels
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Return a logger with the specified name.
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
"""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
def get_verbosity() -> int:
"""
Return the current level for the 🤗 Transformers's root logger as an int.
Returns:
`int`: The logging level.
<Tip>
🤗 Transformers has following logging levels:
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- 40: `diffusers.logging.ERROR`
- 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
- 20: `diffusers.logging.INFO`
- 10: `diffusers.logging.DEBUG`
</Tip>"""
_configure_library_root_logger()
return _get_library_root_logger().getEffectiveLevel()
def set_verbosity(verbosity: int) -> None:
"""
Set the verbosity level for the 🤗 Transformers's root logger.
Args:
verbosity (`int`):
Logging level, e.g., one of:
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- `diffusers.logging.ERROR`
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
- `diffusers.logging.INFO`
- `diffusers.logging.DEBUG`
"""
_configure_library_root_logger()
_get_library_root_logger().setLevel(verbosity)
def set_verbosity_info():
"""Set the verbosity to the `INFO` level."""
return set_verbosity(INFO)
def set_verbosity_warning():
"""Set the verbosity to the `WARNING` level."""
return set_verbosity(WARNING)
def set_verbosity_debug():
"""Set the verbosity to the `DEBUG` level."""
return set_verbosity(DEBUG)
def set_verbosity_error():
"""Set the verbosity to the `ERROR` level."""
return set_verbosity(ERROR)
def disable_default_handler() -> None:
"""Disable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert _default_handler is not None
_get_library_root_logger().removeHandler(_default_handler)
def enable_default_handler() -> None:
"""Enable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert _default_handler is not None
_get_library_root_logger().addHandler(_default_handler)
def add_handler(handler: logging.Handler) -> None:
"""adds a handler to the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert handler is not None
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
"""removes given handler from the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert handler is not None and handler not in _get_library_root_logger().handlers
_get_library_root_logger().removeHandler(handler)
def disable_propagation() -> None:
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""
_configure_library_root_logger()
_get_library_root_logger().propagate = False
def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
prevent double logging if the root logger has been configured.
"""
_configure_library_root_logger()
_get_library_root_logger().propagate = True
def enable_explicit_format() -> None:
"""
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
```
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
```
All handlers currently bound to the root logger are affected by this method.
"""
handlers = _get_library_root_logger().handlers
for handler in handlers:
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
handler.setFormatter(formatter)
def reset_format() -> None:
"""
Resets the formatting for HuggingFace Transformers's loggers.
All handlers currently bound to the root logger are affected by this method.
"""
handlers = _get_library_root_logger().handlers
for handler in handlers:
handler.setFormatter(None)
def warning_advice(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
if no_advisory_warnings:
return
self.warning(*args, **kwargs)
logging.Logger.warning_advice = warning_advice
class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
self._iterator = args[0] if args else None
def __iter__(self):
return iter(self._iterator)
def __getattr__(self, _):
"""Return empty function."""
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
return
return empty_fn
def __enter__(self):
return self
def __exit__(self, type_, value, traceback):
return
class _tqdm_cls:
def __call__(self, *args, **kwargs):
if _tqdm_active:
return tqdm_lib.tqdm(*args, **kwargs)
else:
return EmptyTqdm(*args, **kwargs)
def set_lock(self, *args, **kwargs):
self._lock = None
if _tqdm_active:
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
def get_lock(self):
if _tqdm_active:
return tqdm_lib.tqdm.get_lock()
tqdm = _tqdm_cls()
def is_progress_bar_enabled() -> bool:
"""Return a boolean indicating whether tqdm progress bars are enabled."""
global _tqdm_active
return bool(_tqdm_active)
def enable_progress_bar():
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = True
def disable_progress_bar():
"""Disable tqdm progress bar."""
global _tqdm_active
_tqdm_active = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment