Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
27359ae0
Commit
27359ae0
authored
Jun 20, 2022
by
patil-suraj
Browse files
remove wrong file
parent
95a45f5b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
289 deletions
+0
-289
1
1
+0
-289
No files found.
1
deleted
100644 → 0
View file @
95a45f5b
#
coding
=
utf
-
8
#
Copyright
2022
The
HuggingFace
Inc
.
team
.
#
Copyright
(
c
)
2022
,
NVIDIA
CORPORATION
.
All
rights
reserved
.
#
#
Licensed
under
the
Apache
License
,
Version
2.0
(
the
"License"
);
#
you
may
not
use
this
file
except
in
compliance
with
the
License
.
#
You
may
obtain
a
copy
of
the
License
at
#
#
http
://
www
.
apache
.
org
/
licenses
/
LICENSE
-
2.0
#
#
Unless
required
by
applicable
law
or
agreed
to
in
writing
,
software
#
distributed
under
the
License
is
distributed
on
an
"AS IS"
BASIS
,
#
WITHOUT
WARRANTIES
OR
CONDITIONS
OF
ANY
KIND
,
either
express
or
implied
.
#
See
the
License
for
the
specific
language
governing
permissions
and
#
limitations
under
the
License
.
""" ConfigMixinuration base class and utilities."""
import
inspect
import
json
import
os
import
re
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
huggingface_hub
import
hf_hub_download
from
requests
import
HTTPError
from
.
import
__version__
from
.
utils
import
(
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
logging
,
)
logger
=
logging
.
get_logger
(
__name__
)
_re_configuration_file
=
re
.
compile
(
r
"config\.(.*)\.json"
)
class
ConfigMixin
:
r
"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.
"""
config_name
=
None
def
register_to_config
(
self
,
**
kwargs
):
if
self
.
config_name
is
None
:
raise
NotImplementedError
(
f
"Make sure that {self.__class__} has defined a class name `config_name`"
)
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
kwargs
[
"_diffusers_version"
]
=
__version__
for
key
,
value
in
kwargs
.
items
():
try
:
setattr
(
self
,
key
,
value
)
except
AttributeError
as
err
:
logger
.
error
(
f
"Can't set {key} with value {value} for {self}"
)
raise
err
if
not
hasattr
(
self
,
"_internal_dict"
):
internal_dict
=
kwargs
else
:
previous_dict
=
dict
(
self
.
_internal_dict
)
internal_dict
=
{**
self
.
_internal_dict
,
**
kwargs
}
logger
.
debug
(
f
"Updating config from {previous_dict} to {internal_dict}"
)
self
.
_internal_dict
=
FrozenDict
(
internal_dict
)
def
save_config
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
push_to_hub
:
bool
=
False
,
**
kwargs
):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~ConfigMixin.from_config`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if
os
.
path
.
isfile
(
save_directory
):
raise
AssertionError
(
f
"Provided path ({save_directory}) should be a directory, not a file"
)
os
.
makedirs
(
save_directory
,
exist_ok
=
True
)
#
If
we
save
using
the
predefined
names
,
we
can
load
using
`
from_config
`
output_config_file
=
os
.
path
.
join
(
save_directory
,
self
.
config_name
)
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in {output_config_file}"
)
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
model
=
cls
(**
init_dict
)
if
return_unused_kwargs
:
return
model
,
unused_kwargs
else
:
return
model
@
classmethod
def
get_config_dict
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
)
->
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]:
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
user_agent
=
{
"file_type"
:
"config"
}
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
if
cls
.
config_name
is
None
:
raise
ValueError
(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)):
#
Load
from
a
PyTorch
checkpoint
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)
else
:
raise
EnvironmentError
(
f
"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
)
else
:
try
:
#
Load
from
URL
or
cache
if
already
cached
config_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
cls
.
config_name
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
" this model name. Check the model page at"
f
" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
"There was a specific connection error when trying to load"
f
" {pretrained_model_name_or_path}:
\n
{err}"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f
" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f
" directory containing a {cls.config_name} file.
\n
Checkout your internet connection or see how to"
" run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f
"containing a {cls.config_name} file"
)
try
:
#
Load
config
dict
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
f
"It looks like the config file at '{config_file}' is not a valid JSON file."
)
return
config_dict
@
classmethod
def
extract_init_dict
(
cls
,
config_dict
,
**
kwargs
):
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
.
remove
(
"self"
)
init_dict
=
{}
for
key
in
expected_keys
:
if
key
in
kwargs
:
#
overwrite
key
init_dict
[
key
]
=
kwargs
.
pop
(
key
)
elif
key
in
config_dict
:
#
use
value
from
config
dict
init_dict
[
key
]
=
config_dict
.
pop
(
key
)
unused_kwargs
=
config_dict
.
update
(
kwargs
)
passed_keys
=
set
(
init_dict
.
keys
())
if
len
(
expected_keys
-
passed_keys
)
>
0
:
logger
.
warning
(
f
"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
)
return
init_dict
,
unused_kwargs
@
classmethod
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
with
open
(
json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
text
=
reader
.
read
()
return
json
.
loads
(
text
)
def
__repr__
(
self
):
return
f
"{self.__class__.__name__} {self.to_json_string()}"
@
property
def
config
(
self
)
->
Dict
[
str
,
Any
]:
return
self
.
_internal_dict
def
to_json_string
(
self
)
->
str
:
"""
Serializes this instance to a JSON string.
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
import
ipdb
;
ipdb
.
set_trace
()
config_dict
=
self
.
_internal_dict
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with
open
(
json_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
write
(
self
.
to_json_string
())
class
FrozenDict
(
OrderedDict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(*
args
,
**
kwargs
)
for
key
,
value
in
self
.
items
():
setattr
(
self
,
key
,
value
)
self
.
__frozen
=
True
def
__delitem__
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)
def
setdefault
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)
def
pop
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``pop`` on a {self.__class__.__name__} instance."
)
def
update
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``update`` on a {self.__class__.__name__} instance."
)
def
__setattr__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super
().
__setattr__
(
name
,
value
)
def
__setitem__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super
().
__setitem__
(
name
,
value
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment