Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8efd9ce7
Unverified
Commit
8efd9ce7
authored
Mar 13, 2024
by
Sayak Paul
Committed by
GitHub
Mar 13, 2024
Browse files
[Chore] clean residue from copy-pasting in the UNet single file loader (#7295)
clean residue from copy-pasting
parent
299c16d0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
src/diffusers/loaders/unet.py
src/diffusers/loaders/unet.py
+6
-6
No files found.
src/diffusers/loaders/unet.py
View file @
8efd9ce7
...
@@ -905,14 +905,14 @@ class UNet2DConditionLoadersMixin:
...
@@ -905,14 +905,14 @@ class UNet2DConditionLoadersMixin:
class
FromOriginalUNetMixin
:
class
FromOriginalUNetMixin
:
"""
"""
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`
ControlNetModel
`].
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`
StableCascadeUNet
`].
"""
"""
@
classmethod
@
classmethod
@
validate_hf_hub_args
@
validate_hf_hub_args
def
from_single_file
(
cls
,
pretrained_model_link_or_path
,
**
kwargs
):
def
from_single_file
(
cls
,
pretrained_model_link_or_path
,
**
kwargs
):
r
"""
r
"""
Instantiate a [`
ControlNetModel
`] from pretrained
Control
Net weights saved in the original `.ckpt` or
Instantiate a [`
StableCascadeUNet
`] from pretrained
StableCascadeU
Net weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
Parameters:
...
@@ -951,6 +951,10 @@ class FromOriginalUNetMixin:
...
@@ -951,6 +951,10 @@ class FromOriginalUNetMixin:
Can be used to overwrite load and saveable variables of the model.
Can be used to overwrite load and saveable variables of the model.
"""
"""
class_name
=
cls
.
__name__
if
class_name
!=
"StableCascadeUNet"
:
raise
ValueError
(
"FromOriginalUNetMixin is currently only compatible with StableCascadeUNet"
)
config
=
kwargs
.
pop
(
"config"
,
None
)
config
=
kwargs
.
pop
(
"config"
,
None
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
...
@@ -961,10 +965,6 @@ class FromOriginalUNetMixin:
...
@@ -961,10 +965,6 @@ class FromOriginalUNetMixin:
revision
=
kwargs
.
pop
(
"revision"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
torch_dtype
=
kwargs
.
pop
(
"torch_dtype"
,
None
)
torch_dtype
=
kwargs
.
pop
(
"torch_dtype"
,
None
)
class_name
=
cls
.
__name__
if
class_name
!=
"StableCascadeUNet"
:
raise
ValueError
(
"FromOriginalUNetMixin is currently only compatible with StableCascadeUNet"
)
checkpoint
=
load_single_file_model_checkpoint
(
checkpoint
=
load_single_file_model_checkpoint
(
pretrained_model_link_or_path
,
pretrained_model_link_or_path
,
resume_download
=
resume_download
,
resume_download
=
resume_download
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment