Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
09e1b0b4
Commit
09e1b0b4
authored
Jun 09, 2022
by
Patrick von Platen
Browse files
remove transformers dependency
parent
5a784f98
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
552 additions
and
181 deletions
+552
-181
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+71
-81
src/diffusers/dynamic_modules_utils.py
src/diffusers/dynamic_modules_utils.py
+4
-12
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+60
-80
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+25
-8
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+48
-0
src/diffusers/utils/logging.py
src/diffusers/utils/logging.py
+344
-0
No files found.
src/diffusers/configuration_utils.py
View file @
09e1b0b4
...
...
@@ -24,18 +24,19 @@ import re
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
requests
import
HTTPError
from
transformers.utils
import
(
from
huggingface_hub
import
hf_hub_download
from
.utils
import
(
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
DIFFUSERS_CACHE
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
is_remote_url
,
logging
,
)
from
.
import
__version__
...
...
@@ -89,13 +90,12 @@ class ConfigMixin:
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
@
classmethod
def
get_config_dict
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
)
->
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]:
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
None
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
...
...
@@ -105,85 +105,77 @@ class ConfigMixin:
user_agent
=
{
"file_type"
:
"config"
}
if
is_offline_mode
()
and
not
local_files_only
:
logger
.
info
(
"Offline mode: forcing local_files_only=True"
)
local_files_only
=
True
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
else
:
configuration_file
=
cls
.
config_name
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
configuration_file
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)):
# Load from a PyTorch checkpoint
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)
else
:
config_file
=
hf_bucket_url
(
pretrained_model_name_or_path
,
filename
=
configuration_file
,
revision
=
revision
,
mirror
=
None
raise
EnvironmentError
(
f
"Error no file named
{
cls
.
config_name
}
found in directory
{
pretrained_model_name_or_path
}
."
)
else
:
try
:
# Load from URL or cache if already cached
config_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
cls
.
config_name
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
try
:
# Load from URL or cache if already cached
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f
"model name. Check the model page at 'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for "
"available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
configuration_file
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it in"
f
" the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a directory"
f
" containing a
{
configuration_file
}
file.
\n
Checkout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load config for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a
{
configuration_file
}
file"
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f
"model name. Check the model page at 'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for "
"available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
cls
.
config_name
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it in"
f
" the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a directory"
f
" containing a
{
cls
.
config_name
}
file.
\n
Checkout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load config for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a
{
cls
.
config_name
}
file"
)
try
:
# Load config dict
config_dict
=
cls
.
_dict_from_json_file
(
resolved_
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
f
"It looks like the config file at '
{
resolved_
config_file
}
' is not a valid JSON file."
)
try
:
# Load config dict
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
f
"It looks like the config file at '
{
config_file
}
' is not a valid JSON file."
)
if
resolved_config_file
==
config_file
:
logger
.
info
(
f
"loading configuration file
{
config_file
}
"
)
else
:
logger
.
info
(
f
"loading configuration file
{
config_file
}
from cache at
{
resolved_config_file
}
"
)
return
config_dict
@
classmethod
...
...
@@ -199,9 +191,7 @@ class ConfigMixin:
# use value from config dict
init_dict
[
key
]
=
config_dict
.
pop
(
key
)
unused_kwargs
=
config_dict
.
update
(
kwargs
)
passed_keys
=
set
(
init_dict
.
keys
())
if
len
(
expected_keys
-
passed_keys
)
>
0
:
logger
.
warn
(
...
...
src/diffusers/dynamic_modules_utils.py
View file @
09e1b0b4
...
...
@@ -22,16 +22,8 @@ import sys
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Union
from
huggingface_hub
import
HfFolder
,
model_info
from
transformers.utils
import
(
HF_MODULES_CACHE
,
TRANSFORMERS_DYNAMIC_MODULE_NAME
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
logging
,
)
from
huggingface_hub
import
cached_download
from
.utils
import
HF_MODULES_CACHE
,
DIFFUSERS_DYNAMIC_MODULE_NAME
,
logging
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
...
@@ -219,7 +211,7 @@ def get_cached_module_file(
try
:
# Load from URL or cache if already cached
resolved_module_file
=
cached_
path
(
resolved_module_file
=
cached_
download
(
module_file_or_url
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
...
...
@@ -237,7 +229,7 @@ def get_cached_module_file(
modules_needed
=
check_imports
(
resolved_module_file
)
# Now we move the module inside our cached dynamic modules.
full_submodule
=
TRANSFORM
ERS_DYNAMIC_MODULE_NAME
+
os
.
path
.
sep
+
submodule
full_submodule
=
DIFFUS
ERS_DYNAMIC_MODULE_NAME
+
os
.
path
.
sep
+
submodule
create_dynamic_module
(
full_submodule
)
submodule_path
=
Path
(
HF_MODULES_CACHE
)
/
full_submodule
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
...
...
src/diffusers/modeling_utils.py
View file @
09e1b0b4
...
...
@@ -21,18 +21,15 @@ import torch
from
torch
import
Tensor
,
device
from
requests
import
HTTPError
from
huggingface_hub
import
hf_hub_download
# CHANGE to diffusers.utils
from
transformers.utils
import
(
from
.utils
import
(
CONFIG_NAME
,
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
is_remote_url
,
logging
,
)
...
...
@@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
</Tip>
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
None
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
ignore_mismatched_sizes
=
kwargs
.
pop
(
"ignore_mismatched_sizes"
,
False
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
...
...
@@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
mirror
=
kwargs
.
pop
(
"mirror"
,
None
)
from_auto_class
=
kwargs
.
pop
(
"_from_auto"
,
False
)
user_agent
=
{
"file_type"
:
"model"
,
"framework"
:
"pytorch"
,
"from_auto_class"
:
from_auto_class
}
if
is_offline_mode
()
and
not
local_files_only
:
logger
.
info
(
"Offline mode: forcing local_files_only=True"
)
local_files_only
=
True
# Load config if we don't provide a configuration
config_path
=
pretrained_model_name_or_path
model
,
unused_kwargs
=
cls
.
from_config
(
...
...
@@ -353,83 +345,71 @@ class ModelMixin(torch.nn.Module):
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)):
# Load from a PyTorch checkpoint
archive
_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
model
_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
raise
EnvironmentError
(
f
"Error no file named
{
WEIGHTS_NAME
}
found in directory
{
pretrained_model_name_or_path
}
."
)
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
filename
=
WEIGHTS_NAME
archive_file
=
hf_bucket_url
(
pretrained_model_name_or_path
,
filename
=
filename
,
revision
=
revision
,
mirror
=
mirror
)
try
:
# Load from URL or cache if already cached
model_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
WEIGHTS_NAME
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
try
:
# Load from URL or cache if already cached
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f
"'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
model_file
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it"
f
" in the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a"
f
" directory containing a file named
{
WEIGHTS_NAME
}
or"
"
\n
Checkout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load the model for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a file named
{
WEIGHTS_NAME
}
"
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f
"'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
filename
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it"
f
" in the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a"
f
" directory containing a file named
{
WEIGHTS_NAME
}
or"
"
\n
Checkout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load the model for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a file named
{
WEIGHTS_NAME
}
"
# restore default dtype
state_dict
=
load_state_dict
(
model_file
)
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
state_dict
,
model_file
,
pretrained_model_name_or_path
,
ignore_mismatched_sizes
=
ignore_mismatched_sizes
,
)
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
f
"loading weights file
{
archive_file
}
"
)
else
:
logger
.
info
(
f
"loading weights file
{
archive_file
}
from cache at
{
resolved_archive_file
}
"
)
# restore default dtype
state_dict
=
load_state_dict
(
resolved_archive_file
)
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
state_dict
,
resolved_archive_file
,
pretrained_model_name_or_path
,
ignore_mismatched_sizes
=
ignore_mismatched_sizes
,
)
# Set model in evaluation mode to deactivate DropOut modules by default
model
.
eval
()
...
...
src/diffusers/pipeline_utils.py
View file @
09e1b0b4
...
...
@@ -19,8 +19,7 @@ import os
from
typing
import
Optional
,
Union
from
huggingface_hub
import
snapshot_download
# CHANGE to diffusers.utils
from
transformers.utils
import
logging
from
.utils
import
logging
,
DIFFUSERS_CACHE
from
.configuration_utils
import
ConfigMixin
from
.dynamic_modules_utils
import
get_class_from_dynamic_module
...
...
@@ -55,14 +54,13 @@ class DiffusionPipeline(ConfigMixin):
class_name
=
module
.
__class__
.
__name__
register_dict
=
{
name
:
(
library
,
class_name
)}
# save model index config
self
.
register
(
**
register_dict
)
# set models
setattr
(
self
,
name
,
module
)
register_dict
=
{
"_module"
:
self
.
__module__
.
split
(
"."
)[
-
1
]
+
".py"
}
self
.
register
(
**
register_dict
)
...
...
@@ -94,22 +92,41 @@ class DiffusionPipeline(ConfigMixin):
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
r
"""
Add docstrings
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
output_loading_info
=
kwargs
.
pop
(
"output_loading_info"
,
False
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
)
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resume_download
=
resume_download
,
proxies
=
proxies
,
output_loading_info
=
output_loading_info
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
)
else
:
cached_folder
=
pretrained_model_name_or_path
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
module
=
config_dict
[
"_module"
]
class_name_
=
config_dict
[
"_class_name"
]
if
class_name_
==
cls
.
__name__
:
pipeline_class
=
cls
else
:
pipeline_class
=
get_class_from_dynamic_module
(
cached_folder
,
module
,
class_name_
,
cached_folder
)
init_dict
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
...
...
src/diffusers/utils/__init__.py
0 → 100644
View file @
09e1b0b4
#!/usr/bin/env python
# coding=utf-8
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from requests.exceptions import HTTPError

# NOTE(review): `os` is used below but no `import os` is visible in this chunk —
# confirm it is imported at the top of the file.

# Root of the Hugging Face cache directory. Honors HF_HOME if set, otherwise
# falls back to $XDG_CACHE_HOME/huggingface (defaulting to ~/.cache/huggingface).
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
# Diffusers keeps its own subdirectory inside the shared HF cache.
default_cache_path = os.path.join(hf_cache_home, "diffusers")

# Canonical filename for model/scheduler configuration files on the Hub.
CONFIG_NAME = "config.json"
# Endpoint used to resolve model files.
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
# Package name under which dynamically downloaded modules are importable.
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
# Where dynamically fetched code modules are cached; overridable via env var.
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
# Local error types mirroring the hub's HTTP error taxonomy, presumably defined
# here so this package does not depend on transformers.utils — confirm against
# the versions in huggingface_hub if the two ever need to stay in sync.
class RepositoryNotFoundError(HTTPError):
    """
    Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
    not have access to.
    """


class EntryNotFoundError(HTTPError):
    """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""


class RevisionNotFoundError(HTTPError):
    """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
src/diffusers/utils/logging.py
0 → 100644
View file @
09e1b0b4
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities."""
import
logging
import
os
import
sys
import
threading
from
logging
import
CRITICAL
# NOQA
from
logging
import
DEBUG
# NOQA
from
logging
import
ERROR
# NOQA
from
logging
import
FATAL
# NOQA
from
logging
import
INFO
# NOQA
from
logging
import
NOTSET
# NOQA
from
logging
import
WARN
# NOQA
from
logging
import
WARNING
# NOQA
from
typing
import
Optional
from
tqdm
import
auto
as
tqdm_lib
# Guards lazy configuration of the library root logger (see
# _configure_library_root_logger below).
_lock = threading.Lock()
# The handler installed on the library root logger; None until first use.
_default_handler: Optional[logging.Handler] = None

# String names accepted by the verbosity environment variable, mapped to
# stdlib logging levels.
log_levels = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

# Level used when no environment override is present.
_default_log_level = logging.WARNING

# Module-wide switch controlling whether real tqdm progress bars are shown.
_tqdm_active = True
def _get_default_logging_level():
    """
    If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
    not - fall back to `_default_log_level`
    """
    env_level_str = os.getenv("TRANSFORMERS_VERBOSITY")
    if not env_level_str:
        return _default_log_level
    try:
        return log_levels[env_level_str]
    except KeyError:
        # Unknown value: warn once on the root logger, then use the default.
        logging.getLogger().warning(
            f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
            f"has to be one of: {', '.join(log_levels.keys())}"
        )
        return _default_log_level
def
_get_library_name
()
->
str
:
return
__name__
.
split
(
"."
)[
0
]
def _get_library_root_logger() -> logging.Logger:
    """Return the root logger shared by every module in this library."""
    library_name = _get_library_name()
    return logging.getLogger(library_name)
def _configure_library_root_logger() -> None:
    # Lazily install a single StreamHandler on the library root logger.
    # Idempotent: the lock plus the _default_handler check ensure only the
    # first caller (across threads) performs configuration.
    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        # Bind flush to the *current* sys.stderr so redirected streams
        # (e.g. in tests) are flushed correctly.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_get_default_logging_level())
        # Don't propagate to the application's root logger (avoids double output).
        library_root_logger.propagate = False
def _reset_library_root_logger() -> None:
    # Undo _configure_library_root_logger: detach our handler and clear the
    # level so the library logger behaves as if never configured.
    global _default_handler

    with _lock:
        if not _default_handler:
            # Nothing was ever configured; nothing to undo.
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        # NOTSET makes the logger defer to its ancestors again.
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None
def get_log_levels_dict():
    """Return the mapping of verbosity names to stdlib logging levels.

    Note: returns the module-level dict itself (not a copy), so callers
    should treat it as read-only.
    """
    return log_levels
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return a logger with the specified name.

    This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
    """
    # Make sure the library root logger is set up before handing out children.
    _configure_library_root_logger()
    logger_name = _get_library_name() if name is None else name
    return logging.getLogger(logger_name)
def get_verbosity() -> int:
    """
    Return the current level for the 🤗 Diffusers' root logger as an int.

    Returns:
        `int`: The logging level.

    <Tip>

    🤗 Diffusers has following logging levels:

    - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
    - 40: `diffusers.logging.ERROR`
    - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
    - 20: `diffusers.logging.INFO`
    - 10: `diffusers.logging.DEBUG`

    </Tip>"""

    _configure_library_root_logger()
    return _get_library_root_logger().getEffectiveLevel()
def set_verbosity(verbosity: int) -> None:
    """
    Set the verbosity level for the 🤗 Diffusers' root logger.

    Args:
        verbosity (`int`):
            Logging level, e.g., one of:

            - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
            - `diffusers.logging.ERROR`
            - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
            - `diffusers.logging.INFO`
            - `diffusers.logging.DEBUG`
    """

    _configure_library_root_logger()
    _get_library_root_logger().setLevel(verbosity)
# Convenience wrappers around set_verbosity for each common level.
def set_verbosity_info():
    """Set the verbosity to the `INFO` level."""
    return set_verbosity(INFO)


def set_verbosity_warning():
    """Set the verbosity to the `WARNING` level."""
    return set_verbosity(WARNING)


def set_verbosity_debug():
    """Set the verbosity to the `DEBUG` level."""
    return set_verbosity(DEBUG)


def set_verbosity_error():
    """Set the verbosity to the `ERROR` level."""
    return set_verbosity(ERROR)
def disable_default_handler() -> None:
    """Disable the default handler of the library's root logger."""

    _configure_library_root_logger()

    # NOTE(review): `assert` is stripped under `python -O`; configuration above
    # guarantees the handler exists, so this is a sanity check only.
    assert _default_handler is not None
    _get_library_root_logger().removeHandler(_default_handler)
def enable_default_handler() -> None:
    """Enable the default handler of the library's root logger."""

    _configure_library_root_logger()

    # Sanity check: configuration above always creates the handler.
    assert _default_handler is not None
    _get_library_root_logger().addHandler(_default_handler)
def add_handler(handler: logging.Handler) -> None:
    """Adds a handler to the library's root logger."""

    _configure_library_root_logger()

    # Guard against accidentally passing None (e.g. a handler factory that failed).
    assert handler is not None
    _get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
    """Removes the given handler from the library's root logger.

    Args:
        handler (`logging.Handler`):
            A handler previously attached (e.g. via `add_handler`) to remove.
    """

    _configure_library_root_logger()

    # Bug fix: the original asserted `handler not in ... .handlers`, i.e. that
    # the handler was *absent* right before removing it — the condition was
    # inverted. We must assert the handler is currently attached.
    assert handler is not None and handler in _get_library_root_logger().handlers
    _get_library_root_logger().removeHandler(handler)
def disable_propagation() -> None:
    """
    Disable propagation of the library log outputs. Note that log propagation is disabled by default.
    """

    _configure_library_root_logger()
    _get_library_root_logger().propagate = False
def enable_propagation() -> None:
    """
    Enable propagation of the library log outputs. Please disable the library's default handler to
    prevent double logging if the root logger has been configured.
    """

    _configure_library_root_logger()
    _get_library_root_logger().propagate = True
def enable_explicit_format() -> None:
    """
    Enable explicit formatting for every library logger. The explicit formatter is as follows:
    ```
    [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
    ```
    All handlers currently bound to the root logger are affected by this method.
    """
    # Hoisted out of the loop: the formatter is identical for every handler,
    # so one shared instance suffices (the original built one per handler).
    formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
    for handler in _get_library_root_logger().handlers:
        handler.setFormatter(formatter)
def reset_format() -> None:
    """
    Resets the formatting for the library's loggers.

    All handlers currently bound to the root logger are affected by this method.
    """
    handlers = _get_library_root_logger().handlers

    for handler in handlers:
        # setFormatter(None) restores the stdlib default formatting.
        handler.setFormatter(None)
def warning_advice(self, *args, **kwargs):
    """
    This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
    warning will not be printed
    """
    # Any non-empty env value (not just "1") suppresses the warning, since a
    # non-empty string is truthy.
    no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
    if no_advisory_warnings:
        return
    self.warning(*args, **kwargs)


# Monkey-patch: make `logger.warning_advice(...)` available on every stdlib Logger.
logging.Logger.warning_advice = warning_advice
class EmptyTqdm:
    """No-op stand-in for tqdm used when progress bars are disabled."""

    def __init__(self, *args, **kwargs):
        # Keep the wrapped iterable (first positional argument, mirroring
        # tqdm's signature) so iteration still works; ignore everything else.
        self._iterator = None
        if args:
            self._iterator = args[0]

    def __iter__(self):
        return iter(self._iterator)

    def __getattr__(self, _):
        """Any tqdm method (update, close, ...) resolves to a do-nothing callable."""

        def no_op(*args, **kwargs):
            return None

        return no_op

    def __enter__(self):
        return self

    def __exit__(self, type_, value, traceback):
        # Returning None (falsy) lets exceptions propagate, like tqdm does.
        return None
class _tqdm_cls:
    # Callable facade that dispatches to the real tqdm or to EmptyTqdm based on
    # the module-level _tqdm_active flag, so callers can always `tqdm(...)`.
    def __call__(self, *args, **kwargs):
        if _tqdm_active:
            return tqdm_lib.tqdm(*args, **kwargs)
        else:
            return EmptyTqdm(*args, **kwargs)

    def set_lock(self, *args, **kwargs):
        # NOTE(review): self._lock is set to None unconditionally before
        # delegating — presumably to drop any stale lock reference; confirm.
        self._lock = None
        if _tqdm_active:
            return tqdm_lib.tqdm.set_lock(*args, **kwargs)

    def get_lock(self):
        # Returns None implicitly when progress bars are disabled.
        if _tqdm_active:
            return tqdm_lib.tqdm.get_lock()


# Module-level singleton used throughout the library as a drop-in for tqdm.
tqdm = _tqdm_cls()
def is_progress_bar_enabled() -> bool:
    """Return a boolean indicating whether tqdm progress bars are enabled."""
    global _tqdm_active
    return bool(_tqdm_active)
# Toggle the module-level flag read by the `tqdm` facade above.
def enable_progress_bar():
    """Enable tqdm progress bar."""
    global _tqdm_active
    _tqdm_active = True


def disable_progress_bar():
    """Disable tqdm progress bar."""
    global _tqdm_active
    _tqdm_active = False
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment