Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
09e1b0b4
Commit
09e1b0b4
authored
Jun 09, 2022
by
Patrick von Platen
Browse files
remove transformers dependency
parent
5a784f98
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
552 additions
and
181 deletions
+552
-181
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+71
-81
src/diffusers/dynamic_modules_utils.py
src/diffusers/dynamic_modules_utils.py
+4
-12
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+60
-80
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+25
-8
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+48
-0
src/diffusers/utils/logging.py
src/diffusers/utils/logging.py
+344
-0
No files found.
src/diffusers/configuration_utils.py
View file @
09e1b0b4
...
...
@@ -24,18 +24,19 @@ import re
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
requests
import
HTTPError
from
transformers.utils
import
(
from
huggingface_hub
import
hf_hub_download
from
.utils
import
(
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
DIFFUSERS_CACHE
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
is_remote_url
,
logging
,
)
from
.
import
__version__
...
...
@@ -89,13 +90,12 @@ class ConfigMixin:
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
@
classmethod
def
get_config_dict
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
)
->
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]:
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
None
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
...
...
@@ -105,85 +105,77 @@ class ConfigMixin:
user_agent
=
{
"file_type"
:
"config"
}
if
is_offline_mode
()
and
not
local_files_only
:
logger
.
info
(
"Offline mode: forcing local_files_only=True"
)
local_files_only
=
True
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
else
:
configuration_file
=
cls
.
config_name
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
configuration_file
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)):
# Load from a PyTorch checkpoint
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)
else
:
config_file
=
hf_bucket_url
(
pretrained_model_name_or_path
,
filename
=
configuration_file
,
revision
=
revision
,
mirror
=
None
raise
EnvironmentError
(
f
"Error no file named
{
cls
.
config_name
}
found in directory
{
pretrained_model_name_or_path
}
."
)
else
:
try
:
# Load from URL or cache if already cached
config_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
cls
.
config_name
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
try
:
# Load from URL or cache if already cached
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f
"model name. Check the model page at 'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for "
"available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
configuration_file
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it in"
f
" the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a directory"
f
" containing a
{
configuration_file
}
file.
\n
Checkout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load config for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a
{
configuration_file
}
file"
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f
"model name. Check the model page at 'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for "
"available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
cls
.
config_name
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it in"
f
" the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a directory"
f
" containing a
{
cls
.
config_name
}
file.
\n
Checkout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load config for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a
{
cls
.
config_name
}
file"
)
try
:
# Load config dict
config_dict
=
cls
.
_dict_from_json_file
(
resolved_
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
f
"It looks like the config file at '
{
resolved_
config_file
}
' is not a valid JSON file."
)
try
:
# Load config dict
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
f
"It looks like the config file at '
{
config_file
}
' is not a valid JSON file."
)
if
resolved_config_file
==
config_file
:
logger
.
info
(
f
"loading configuration file
{
config_file
}
"
)
else
:
logger
.
info
(
f
"loading configuration file
{
config_file
}
from cache at
{
resolved_config_file
}
"
)
return
config_dict
@
classmethod
...
...
@@ -199,9 +191,7 @@ class ConfigMixin:
# use value from config dict
init_dict
[
key
]
=
config_dict
.
pop
(
key
)
unused_kwargs
=
config_dict
.
update
(
kwargs
)
passed_keys
=
set
(
init_dict
.
keys
())
if
len
(
expected_keys
-
passed_keys
)
>
0
:
logger
.
warn
(
...
...
src/diffusers/dynamic_modules_utils.py
View file @
09e1b0b4
...
...
@@ -22,16 +22,8 @@ import sys
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Union
from
huggingface_hub
import
HfFolder
,
model_info
from
transformers.utils
import
(
HF_MODULES_CACHE
,
TRANSFORMERS_DYNAMIC_MODULE_NAME
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
logging
,
)
from
huggingface_hub
import
cached_download
from
.utils
import
HF_MODULES_CACHE
,
DIFFUSERS_DYNAMIC_MODULE_NAME
,
logging
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -219,7 +211,7 @@ def get_cached_module_file(
try
:
# Load from URL or cache if already cached
resolved_module_file
=
cached_
path
(
resolved_module_file
=
cached_
download
(
module_file_or_url
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
...
...
@@ -237,7 +229,7 @@ def get_cached_module_file(
modules_needed
=
check_imports
(
resolved_module_file
)
# Now we move the module inside our cached dynamic modules.
full_submodule
=
TRANSFORM
ERS_DYNAMIC_MODULE_NAME
+
os
.
path
.
sep
+
submodule
full_submodule
=
DIFFUS
ERS_DYNAMIC_MODULE_NAME
+
os
.
path
.
sep
+
submodule
create_dynamic_module
(
full_submodule
)
submodule_path
=
Path
(
HF_MODULES_CACHE
)
/
full_submodule
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
...
...
src/diffusers/modeling_utils.py
View file @
09e1b0b4
...
...
@@ -21,18 +21,15 @@ import torch
from
torch
import
Tensor
,
device
from
requests
import
HTTPError
from
huggingface_hub
import
hf_hub_download
# CHANGE to diffusers.utils
from
transformers.utils
import
(
from
.utils
import
(
CONFIG_NAME
,
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
is_remote_url
,
logging
,
)
...
...
@@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
</Tip>
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
None
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
ignore_mismatched_sizes
=
kwargs
.
pop
(
"ignore_mismatched_sizes"
,
False
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
...
...
@@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
mirror
=
kwargs
.
pop
(
"mirror"
,
None
)
from_auto_class
=
kwargs
.
pop
(
"_from_auto"
,
False
)
user_agent
=
{
"file_type"
:
"model"
,
"framework"
:
"pytorch"
,
"from_auto_class"
:
from_auto_class
}
if
is_offline_mode
()
and
not
local_files_only
:
logger
.
info
(
"Offline mode: forcing local_files_only=True"
)
local_files_only
=
True
# Load config if we don't provide a configuration
config_path
=
pretrained_model_name_or_path
model
,
unused_kwargs
=
cls
.
from_config
(
...
...
@@ -353,83 +345,71 @@ class ModelMixin(torch.nn.Module):
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)):
# Load from a PyTorch checkpoint
archive
_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
model
_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
raise
EnvironmentError
(
f
"Error no file named
{
WEIGHTS_NAME
}
found in directory
{
pretrained_model_name_or_path
}
."
)
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
filename
=
WEIGHTS_NAME
archive_file
=
hf_bucket_url
(
pretrained_model_name_or_path
,
filename
=
filename
,
revision
=
revision
,
mirror
=
mirror
)
try
:
# Load from URL or cache if already cached
model_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
WEIGHTS_NAME
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
try
:
# Load from URL or cache if already cached
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f
"'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
model_file
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it"
f
" in the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a"
f
" directory containing a file named
{
WEIGHTS_NAME
}
or"
"
\n
Checkout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load the model for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a file named
{
WEIGHTS_NAME
}
"
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f
"'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
filename
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it"
f
" in the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a"
f
" directory containing a file named
{
WEIGHTS_NAME
}
or"
"
\n
Checkout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load the model for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a file named
{
WEIGHTS_NAME
}
"
# restore default dtype
state_dict
=
load_state_dict
(
model_file
)
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
state_dict
,
model_file
,
pretrained_model_name_or_path
,
ignore_mismatched_sizes
=
ignore_mismatched_sizes
,
)
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
f
"loading weights file
{
archive_file
}
"
)
else
:
logger
.
info
(
f
"loading weights file
{
archive_file
}
from cache at
{
resolved_archive_file
}
"
)
# restore default dtype
state_dict
=
load_state_dict
(
resolved_archive_file
)
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
state_dict
,
resolved_archive_file
,
pretrained_model_name_or_path
,
ignore_mismatched_sizes
=
ignore_mismatched_sizes
,
)
# Set model in evaluation mode to deactivate DropOut modules by default
model
.
eval
()
...
...
src/diffusers/pipeline_utils.py
View file @
09e1b0b4
...
...
@@ -19,8 +19,7 @@ import os
from
typing
import
Optional
,
Union
from
huggingface_hub
import
snapshot_download
# CHANGE to diffusers.utils
from
transformers.utils
import
logging
from
.utils
import
logging
,
DIFFUSERS_CACHE
from
.configuration_utils
import
ConfigMixin
from
.dynamic_modules_utils
import
get_class_from_dynamic_module
...
...
@@ -55,14 +54,13 @@ class DiffusionPipeline(ConfigMixin):
class_name
=
module
.
__class__
.
__name__
register_dict
=
{
name
:
(
library
,
class_name
)}
# save model index config
self
.
register
(
**
register_dict
)
# set models
setattr
(
self
,
name
,
module
)
register_dict
=
{
"_module"
:
self
.
__module__
.
split
(
"."
)[
-
1
]
+
".py"
}
self
.
register
(
**
register_dict
)
...
...
@@ -94,22 +92,41 @@ class DiffusionPipeline(ConfigMixin):
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
r
"""
Add docstrings
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
output_loading_info
=
kwargs
.
pop
(
"output_loading_info"
,
False
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
)
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resume_download
=
resume_download
,
proxies
=
proxies
,
output_loading_info
=
output_loading_info
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
)
else
:
cached_folder
=
pretrained_model_name_or_path
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
module
=
config_dict
[
"_module"
]
class_name_
=
config_dict
[
"_class_name"
]
if
class_name_
==
cls
.
__name__
:
pipeline_class
=
cls
else
:
pipeline_class
=
get_class_from_dynamic_module
(
cached_folder
,
module
,
class_name_
,
cached_folder
)
init_dict
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
...
...
src/diffusers/utils/__init__.py
0 → 100644
View file @
09e1b0b4
#!/usr/bin/env python
# coding=utf-8
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
requests.exceptions
import
HTTPError
hf_cache_home
=
os
.
path
.
expanduser
(
os
.
getenv
(
"HF_HOME"
,
os
.
path
.
join
(
os
.
getenv
(
"XDG_CACHE_HOME"
,
"~/.cache"
),
"huggingface"
))
)
default_cache_path
=
os
.
path
.
join
(
hf_cache_home
,
"diffusers"
)
CONFIG_NAME
=
"config.json"
HUGGINGFACE_CO_RESOLVE_ENDPOINT
=
"https://huggingface.co"
DIFFUSERS_CACHE
=
default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME
=
"diffusers_modules"
HF_MODULES_CACHE
=
os
.
getenv
(
"HF_MODULES_CACHE"
,
os
.
path
.
join
(
hf_cache_home
,
"modules"
))
class
RepositoryNotFoundError
(
HTTPError
):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
not have access to.
"""
class
EntryNotFoundError
(
HTTPError
):
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
class
RevisionNotFoundError
(
HTTPError
):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
src/diffusers/utils/logging.py
0 → 100644
View file @
09e1b0b4
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities."""
import
logging
import
os
import
sys
import
threading
from
logging
import
CRITICAL
# NOQA
from
logging
import
DEBUG
# NOQA
from
logging
import
ERROR
# NOQA
from
logging
import
FATAL
# NOQA
from
logging
import
INFO
# NOQA
from
logging
import
NOTSET
# NOQA
from
logging
import
WARN
# NOQA
from
logging
import
WARNING
# NOQA
from
typing
import
Optional
from
tqdm
import
auto
as
tqdm_lib
_lock
=
threading
.
Lock
()
_default_handler
:
Optional
[
logging
.
Handler
]
=
None
log_levels
=
{
"debug"
:
logging
.
DEBUG
,
"info"
:
logging
.
INFO
,
"warning"
:
logging
.
WARNING
,
"error"
:
logging
.
ERROR
,
"critical"
:
logging
.
CRITICAL
,
}
_default_log_level
=
logging
.
WARNING
_tqdm_active
=
True
def
_get_default_logging_level
():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level`
"""
env_level_str
=
os
.
getenv
(
"TRANSFORMERS_VERBOSITY"
,
None
)
if
env_level_str
:
if
env_level_str
in
log_levels
:
return
log_levels
[
env_level_str
]
else
:
logging
.
getLogger
().
warning
(
f
"Unknown option TRANSFORMERS_VERBOSITY=
{
env_level_str
}
, "
f
"has to be one of:
{
', '
.
join
(
log_levels
.
keys
())
}
"
)
return
_default_log_level
def
_get_library_name
()
->
str
:
return
__name__
.
split
(
"."
)[
0
]
def
_get_library_root_logger
()
->
logging
.
Logger
:
return
logging
.
getLogger
(
_get_library_name
())
def
_configure_library_root_logger
()
->
None
:
global
_default_handler
with
_lock
:
if
_default_handler
:
# This library has already configured the library root logger.
return
_default_handler
=
logging
.
StreamHandler
()
# Set sys.stderr as stream.
_default_handler
.
flush
=
sys
.
stderr
.
flush
# Apply our default configuration to the library root logger.
library_root_logger
=
_get_library_root_logger
()
library_root_logger
.
addHandler
(
_default_handler
)
library_root_logger
.
setLevel
(
_get_default_logging_level
())
library_root_logger
.
propagate
=
False
def
_reset_library_root_logger
()
->
None
:
global
_default_handler
with
_lock
:
if
not
_default_handler
:
return
library_root_logger
=
_get_library_root_logger
()
library_root_logger
.
removeHandler
(
_default_handler
)
library_root_logger
.
setLevel
(
logging
.
NOTSET
)
_default_handler
=
None
def
get_log_levels_dict
():
return
log_levels
def
get_logger
(
name
:
Optional
[
str
]
=
None
)
->
logging
.
Logger
:
"""
Return a logger with the specified name.
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
"""
if
name
is
None
:
name
=
_get_library_name
()
_configure_library_root_logger
()
return
logging
.
getLogger
(
name
)
def
get_verbosity
()
->
int
:
"""
Return the current level for the 🤗 Transformers's root logger as an int.
Returns:
`int`: The logging level.
<Tip>
🤗 Transformers has following logging levels:
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- 40: `diffusers.logging.ERROR`
- 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
- 20: `diffusers.logging.INFO`
- 10: `diffusers.logging.DEBUG`
</Tip>"""
_configure_library_root_logger
()
return
_get_library_root_logger
().
getEffectiveLevel
()
def
set_verbosity
(
verbosity
:
int
)
->
None
:
"""
Set the verbosity level for the 🤗 Transformers's root logger.
Args:
verbosity (`int`):
Logging level, e.g., one of:
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- `diffusers.logging.ERROR`
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
- `diffusers.logging.INFO`
- `diffusers.logging.DEBUG`
"""
_configure_library_root_logger
()
_get_library_root_logger
().
setLevel
(
verbosity
)
def
set_verbosity_info
():
"""Set the verbosity to the `INFO` level."""
return
set_verbosity
(
INFO
)
def
set_verbosity_warning
():
"""Set the verbosity to the `WARNING` level."""
return
set_verbosity
(
WARNING
)
def
set_verbosity_debug
():
"""Set the verbosity to the `DEBUG` level."""
return
set_verbosity
(
DEBUG
)
def
set_verbosity_error
():
"""Set the verbosity to the `ERROR` level."""
return
set_verbosity
(
ERROR
)
def
disable_default_handler
()
->
None
:
"""Disable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger
()
assert
_default_handler
is
not
None
_get_library_root_logger
().
removeHandler
(
_default_handler
)
def
enable_default_handler
()
->
None
:
"""Enable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger
()
assert
_default_handler
is
not
None
_get_library_root_logger
().
addHandler
(
_default_handler
)
def
add_handler
(
handler
:
logging
.
Handler
)
->
None
:
"""adds a handler to the HuggingFace Transformers's root logger."""
_configure_library_root_logger
()
assert
handler
is
not
None
_get_library_root_logger
().
addHandler
(
handler
)
def
remove_handler
(
handler
:
logging
.
Handler
)
->
None
:
"""removes given handler from the HuggingFace Transformers's root logger."""
_configure_library_root_logger
()
assert
handler
is
not
None
and
handler
not
in
_get_library_root_logger
().
handlers
_get_library_root_logger
().
removeHandler
(
handler
)
def
disable_propagation
()
->
None
:
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""
_configure_library_root_logger
()
_get_library_root_logger
().
propagate
=
False
def
enable_propagation
()
->
None
:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
prevent double logging if the root logger has been configured.
"""
_configure_library_root_logger
()
_get_library_root_logger
().
propagate
=
True
def
enable_explicit_format
()
->
None
:
"""
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
```
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
```
All handlers currently bound to the root logger are affected by this method.
"""
handlers
=
_get_library_root_logger
().
handlers
for
handler
in
handlers
:
formatter
=
logging
.
Formatter
(
"[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s"
)
handler
.
setFormatter
(
formatter
)
def
reset_format
()
->
None
:
"""
Resets the formatting for HuggingFace Transformers's loggers.
All handlers currently bound to the root logger are affected by this method.
"""
handlers
=
_get_library_root_logger
().
handlers
for
handler
in
handlers
:
handler
.
setFormatter
(
None
)
def
warning_advice
(
self
,
*
args
,
**
kwargs
):
"""
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
no_advisory_warnings
=
os
.
getenv
(
"TRANSFORMERS_NO_ADVISORY_WARNINGS"
,
False
)
if
no_advisory_warnings
:
return
self
.
warning
(
*
args
,
**
kwargs
)
logging
.
Logger
.
warning_advice
=
warning_advice
class
EmptyTqdm
:
"""Dummy tqdm which doesn't do anything."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
self
.
_iterator
=
args
[
0
]
if
args
else
None
def
__iter__
(
self
):
return
iter
(
self
.
_iterator
)
def
__getattr__
(
self
,
_
):
"""Return empty function."""
def
empty_fn
(
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
return
return
empty_fn
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
type_
,
value
,
traceback
):
return
class
_tqdm_cls
:
def
__call__
(
self
,
*
args
,
**
kwargs
):
if
_tqdm_active
:
return
tqdm_lib
.
tqdm
(
*
args
,
**
kwargs
)
else
:
return
EmptyTqdm
(
*
args
,
**
kwargs
)
def
set_lock
(
self
,
*
args
,
**
kwargs
):
self
.
_lock
=
None
if
_tqdm_active
:
return
tqdm_lib
.
tqdm
.
set_lock
(
*
args
,
**
kwargs
)
def
get_lock
(
self
):
if
_tqdm_active
:
return
tqdm_lib
.
tqdm
.
get_lock
()
tqdm
=
_tqdm_cls
()
def
is_progress_bar_enabled
()
->
bool
:
"""Return a boolean indicating whether tqdm progress bars are enabled."""
global
_tqdm_active
return
bool
(
_tqdm_active
)
def
enable_progress_bar
():
"""Enable tqdm progress bar."""
global
_tqdm_active
_tqdm_active
=
True
def
disable_progress_bar
():
"""Disable tqdm progress bar."""
global
_tqdm_active
_tqdm_active
=
False
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment