chenpangpang / diffusers

[Single File] Fix configuring scheduler via legacy kwargs (#9229)

update

Unverified commit 750bd792, authored Aug 21, 2024 by Dhruv Nair, committed by GitHub on Aug 21, 2024
Parent commit: 214372aa
Showing 2 changed files with 18 additions and 5 deletions (+18 -5)
src/diffusers/loaders/single_file.py (+2 -2)
src/diffusers/loaders/single_file_utils.py (+16 -3)
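For orientation before the diffs: `from_single_file` accepts the legacy scheduler kwargs `scheduler_type` and `prediction_type`, but the old gate shown below only honoured them when `is_legacy_loading` was already set. A minimal, hypothetical usage sketch of the kind of call this commit fixes (the checkpoint filename is a placeholder, and both kwargs are deprecated in favour of passing a scheduler object, as the updated messages further down note):

# Hypothetical usage sketch; "v1-5-pruned-emaonly.safetensors" is a placeholder checkpoint path.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "v1-5-pruned-emaonly.safetensors",
    prediction_type="v_prediction",  # legacy kwarg; with this fix it is routed to _legacy_load_scheduler
)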
src/diffusers/loaders/single_file.py
@@ -23,6 +23,7 @@ from packaging import version
 from ..utils import deprecate, is_transformers_available, logging
 from .single_file_utils import (
     SingleFileComponentError,
+    _is_legacy_scheduler_kwargs,
     _is_model_weights_in_cached_folder,
     _legacy_load_clip_tokenizer,
     _legacy_load_safety_checker,
@@ -42,7 +43,6 @@ logger = logging.get_logger(__name__)
 # Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
 SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]

 if is_transformers_available():
     import transformers
     from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -135,7 +135,7 @@ def load_single_file_sub_model(
             class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
         )
-    elif is_diffusers_scheduler and is_legacy_loading:
+    elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
         loaded_sub_model = _legacy_load_scheduler(
             class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
         )
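To make the updated gate concrete, here is a small self-contained sketch. The two-line helper is copied from the single_file_utils.py hunk below; the boolean flags and the wrapper function name are illustrative stand-ins for values computed inside load_single_file_sub_model, not the real call sites.

SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]


def _is_legacy_scheduler_kwargs(kwargs):
    # Copied from the hunk added to single_file_utils.py below.
    return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())


def takes_legacy_scheduler_path(is_diffusers_scheduler, is_legacy_loading, kwargs):
    # Hypothetical stand-in for the updated `elif` condition above.
    return is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs))


# A legacy scheduler kwarg alone now selects the legacy scheduler loader;
# before this commit the same call required is_legacy_loading to be True.
assert takes_legacy_scheduler_path(True, False, {"prediction_type": "v_prediction"})
assert not takes_legacy_scheduler_path(True, False, {"torch_dtype": "float16"})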
src/diffusers/loaders/single_file_utils.py
@@ -269,6 +269,7 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
 ]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
+SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]

 VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
@@ -318,6 +319,10 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
     return weights_exist


+def _is_legacy_scheduler_kwargs(kwargs):
+    return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
+
+
 def load_single_file_checkpoint(
     pretrained_model_link_or_path,
     force_download=False,
@@ -1479,14 +1484,22 @@ def _legacy_load_scheduler(
     if scheduler_type is not None:
         deprecation_message = (
-            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
+            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
+            "Example:\n\n"
+            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
+            "scheduler = DDIMScheduler()\n"
+            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
         )
         deprecate("scheduler_type", "1.0.0", deprecation_message)

     if prediction_type is not None:
         deprecation_message = (
-            "Please configure an instance of a Scheduler with the appropriate `prediction_type` "
-            "and pass the object directly to the `scheduler` argument in `from_single_file`."
+            "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
+            "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
+            "Example:\n\n"
+            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
+            'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
+            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
         )
         deprecate("prediction_type", "1.0.0", deprecation_message)
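For reference, a runnable form of the non-deprecated pattern that the updated deprecation messages recommend; the checkpoint path is a placeholder.

# Sketch of the replacement pattern shown in the deprecation messages above.
# "v1-5-pruned-emaonly.safetensors" is a placeholder checkpoint path.
from diffusers import StableDiffusionPipeline, DDIMScheduler

scheduler = DDIMScheduler(prediction_type="v_prediction")
pipe = StableDiffusionPipeline.from_single_file(
    "v1-5-pruned-emaonly.safetensors",
    scheduler=scheduler,
)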