Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
750bd792
Unverified
Commit
750bd792
authored
Aug 21, 2024
by
Dhruv Nair
Committed by
GitHub
Aug 21, 2024
Browse files
[Single File] Fix configuring scheduler via legacy kwargs (#9229)
update
parent
214372aa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
5 deletions
+18
-5
src/diffusers/loaders/single_file.py
src/diffusers/loaders/single_file.py
+2
-2
src/diffusers/loaders/single_file_utils.py
src/diffusers/loaders/single_file_utils.py
+16
-3
No files found.
src/diffusers/loaders/single_file.py
View file @
750bd792
...
@@ -23,6 +23,7 @@ from packaging import version
...
@@ -23,6 +23,7 @@ from packaging import version
from
..utils
import
deprecate
,
is_transformers_available
,
logging
from
..utils
import
deprecate
,
is_transformers_available
,
logging
from
.single_file_utils
import
(
from
.single_file_utils
import
(
SingleFileComponentError
,
SingleFileComponentError
,
_is_legacy_scheduler_kwargs
,
_is_model_weights_in_cached_folder
,
_is_model_weights_in_cached_folder
,
_legacy_load_clip_tokenizer
,
_legacy_load_clip_tokenizer
,
_legacy_load_safety_checker
,
_legacy_load_safety_checker
,
...
@@ -42,7 +43,6 @@ logger = logging.get_logger(__name__)
...
@@ -42,7 +43,6 @@ logger = logging.get_logger(__name__)
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
SINGLE_FILE_OPTIONAL_COMPONENTS
=
[
"safety_checker"
]
SINGLE_FILE_OPTIONAL_COMPONENTS
=
[
"safety_checker"
]
if
is_transformers_available
():
if
is_transformers_available
():
import
transformers
import
transformers
from
transformers
import
PreTrainedModel
,
PreTrainedTokenizer
from
transformers
import
PreTrainedModel
,
PreTrainedTokenizer
...
@@ -135,7 +135,7 @@ def load_single_file_sub_model(
...
@@ -135,7 +135,7 @@ def load_single_file_sub_model(
class_obj
,
checkpoint
=
checkpoint
,
config
=
cached_model_config_path
,
local_files_only
=
local_files_only
class_obj
,
checkpoint
=
checkpoint
,
config
=
cached_model_config_path
,
local_files_only
=
local_files_only
)
)
elif
is_diffusers_scheduler
and
is_legacy_loading
:
elif
is_diffusers_scheduler
and
(
is_legacy_loading
or
_is_legacy_scheduler_kwargs
(
kwargs
))
:
loaded_sub_model
=
_legacy_load_scheduler
(
loaded_sub_model
=
_legacy_load_scheduler
(
class_obj
,
checkpoint
=
checkpoint
,
component_name
=
name
,
original_config
=
original_config
,
**
kwargs
class_obj
,
checkpoint
=
checkpoint
,
component_name
=
name
,
original_config
=
original_config
,
**
kwargs
)
)
...
...
src/diffusers/loaders/single_file_utils.py
View file @
750bd792
...
@@ -269,6 +269,7 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
...
@@ -269,6 +269,7 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
]
]
OPEN_CLIP_PREFIX
=
"conditioner.embedders.0.model."
OPEN_CLIP_PREFIX
=
"conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
=
1024
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
=
1024
SCHEDULER_LEGACY_KWARGS
=
[
"prediction_type"
,
"scheduler_type"
]
VALID_URL_PREFIXES
=
[
"https://huggingface.co/"
,
"huggingface.co/"
,
"hf.co/"
,
"https://hf.co/"
]
VALID_URL_PREFIXES
=
[
"https://huggingface.co/"
,
"huggingface.co/"
,
"hf.co/"
,
"https://hf.co/"
]
...
@@ -318,6 +319,10 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
...
@@ -318,6 +319,10 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
return
weights_exist
return
weights_exist
def
_is_legacy_scheduler_kwargs
(
kwargs
):
return
any
(
k
in
SCHEDULER_LEGACY_KWARGS
for
k
in
kwargs
.
keys
())
def
load_single_file_checkpoint
(
def
load_single_file_checkpoint
(
pretrained_model_link_or_path
,
pretrained_model_link_or_path
,
force_download
=
False
,
force_download
=
False
,
...
@@ -1479,14 +1484,22 @@ def _legacy_load_scheduler(
...
@@ -1479,14 +1484,22 @@ def _legacy_load_scheduler(
if
scheduler_type
is
not
None
:
if
scheduler_type
is
not
None
:
deprecation_message
=
(
deprecation_message
=
(
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`
\n\n
"
"Example:
\n\n
"
"from diffusers import StableDiffusionPipeline, DDIMScheduler
\n\n
"
"scheduler = DDIMScheduler()
\n
"
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)
\n
"
)
)
deprecate
(
"scheduler_type"
,
"1.0.0"
,
deprecation_message
)
deprecate
(
"scheduler_type"
,
"1.0.0"
,
deprecation_message
)
if
prediction_type
is
not
None
:
if
prediction_type
is
not
None
:
deprecation_message
=
(
deprecation_message
=
(
"Please configure an instance of a Scheduler with the appropriate `prediction_type` "
"Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
"and pass the object directly to the `scheduler` argument in `from_single_file`."
"pass the object directly to the `scheduler` argument in `from_single_file`.
\n\n
"
"Example:
\n\n
"
"from diffusers import StableDiffusionPipeline, DDIMScheduler
\n\n
"
'scheduler = DDIMScheduler(prediction_type="v_prediction")
\n
'
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)
\n
"
)
)
deprecate
(
"prediction_type"
,
"1.0.0"
,
deprecation_message
)
deprecate
(
"prediction_type"
,
"1.0.0"
,
deprecation_message
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment