renzhc / diffusers_dcu · Commit a0422ed0 (unverified)

Authored Jul 25, 2023 by Patrick von Platen; committed Jul 25, 2023 by GitHub
[From Single File] Allow vae to be loaded (#4242)
* Allow vae to be loaded

* up
Parent: 3dd33937

Showing 2 changed files with 11 additions and 2 deletions:

  src/diffusers/loaders.py (+5 -0)
  src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py (+6 -2)
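In effect, this commit lets a caller hand a preloaded `AutoencoderKL` to `from_single_file` instead of having the loader convert the VAE weights embedded in the checkpoint. A minimal usage sketch; the local checkpoint filename and the fine-tuned VAE repo below are illustrative placeholders, not part of this commit:

    from diffusers import AutoencoderKL, StableDiffusionPipeline

    # Illustrative: any diffusers-format VAE repo works here.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

    pipe = StableDiffusionPipeline.from_single_file(
        "./v1-5-pruned-emaonly.safetensors",  # placeholder local checkpoint
        vae=vae,  # new in this commit: reuse the instance instead of converting
    )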
src/diffusers/loaders.py
@@ -1410,6 +1410,9 @@ class FromSingleFileMixin:
                 An instance of `CLIPTextModel` to use, specifically the
                 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
                 parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
+            vae (`AutoencoderKL`, *optional*, defaults to `None`):
+                Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
+                this parameter is `None`, the function loads a new instance of `AutoencoderKL` by itself, if needed.
             tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
                 An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
                 of `CLIPTokenizer` by itself if needed.
@@ -1458,6 +1461,7 @@ class FromSingleFileMixin:
         load_safety_checker = kwargs.pop("load_safety_checker", True)
         prediction_type = kwargs.pop("prediction_type", None)
         text_encoder = kwargs.pop("text_encoder", None)
+        vae = kwargs.pop("vae", None)
         controlnet = kwargs.pop("controlnet", None)
         tokenizer = kwargs.pop("tokenizer", None)
@@ -1548,6 +1552,7 @@ class FromSingleFileMixin:
             load_safety_checker=load_safety_checker,
             prediction_type=prediction_type,
             text_encoder=text_encoder,
+            vae=vae,
             tokenizer=tokenizer,
         )
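The three hunks above follow the mixin's existing pattern: each optional component is popped out of `**kwargs` with a `None` default and forwarded explicitly, so a `None` downstream means "build this component from the checkpoint". A self-contained sketch of that pattern; the function names here are illustrative, not the real API:

    def _convert_sketch(vae=None, text_encoder=None):
        # Stand-in for download_from_original_stable_diffusion_ckpt:
        # any component left as None would be built from the checkpoint.
        return {"vae": vae, "text_encoder": text_encoder}

    def from_single_file_sketch(**kwargs):
        # Pop user-supplied components; a missing key falls back to None.
        vae = kwargs.pop("vae", None)
        text_encoder = kwargs.pop("text_encoder", None)
        return _convert_sketch(vae=vae, text_encoder=text_encoder)

    print(from_single_file_sketch(vae="my-vae"))
    # -> {'vae': 'my-vae', 'text_encoder': None}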
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -1107,6 +1107,7 @@ def download_from_original_stable_diffusion_ckpt(
     pipeline_class: DiffusionPipeline = None,
     local_files_only=False,
     vae_path=None,
+    vae=None,
     text_encoder=None,
     tokenizer=None,
 ) -> DiffusionPipeline:
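Callers that use the converter directly, rather than through `from_single_file`, can pass the new parameter the same way. A hedged sketch; the checkpoint path and the VAE repo are placeholders:

    from diffusers import AutoencoderKL
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # placeholder repo
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path="./model.ckpt",  # placeholder path
        vae=vae,  # new parameter introduced by this commit
    )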
@@ -1156,6 +1157,9 @@ def download_from_original_stable_diffusion_ckpt(
             The pipeline class to use. Pass `None` to determine automatically.
         local_files_only (`bool`, *optional*, defaults to `False`):
             Whether or not to only look at local files (i.e., do not try to download the model).
+        vae (`AutoencoderKL`, *optional*, defaults to `None`):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
+            this parameter is `None`, the function loads a new instance of `AutoencoderKL` by itself, if needed.
         text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
             An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
             to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
@@ -1361,7 +1365,7 @@ def download_from_original_stable_diffusion_ckpt(
     unet.load_state_dict(converted_unet_checkpoint)

     # Convert the VAE model.
-    if vae_path is None:
+    if vae_path is None and vae is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
@@ -1385,7 +1389,7 @@ def download_from_original_stable_diffusion_ckpt(
                 set_module_tensor_to_device(vae, param_name, "cpu", value=param)
         else:
             vae.load_state_dict(converted_vae_checkpoint)
-    else:
+    elif vae is None:
         vae = AutoencoderKL.from_pretrained(vae_path)

     if model_type == "FrozenOpenCLIPEmbedder":
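Taken together, the two hunks above give the VAE selection three branches. A condensed sketch of the resulting control flow (checkpoint-conversion details and the `low_cpu_mem_usage` path omitted; not the verbatim implementation):

    # 1. No override at all: convert the VAE weights embedded in the checkpoint.
    if vae_path is None and vae is None:
        vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
        vae = AutoencoderKL(**vae_config)
        vae.load_state_dict(converted_vae_checkpoint)
    # 2. A path was given but no instance: load the VAE from that path.
    elif vae is None:
        vae = AutoencoderKL.from_pretrained(vae_path)
    # 3. Otherwise the caller's `vae` instance is used as-is, untouched.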