Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
5729829c
Unverified
Commit
5729829c
authored
Jul 17, 2023
by
Patrick von Platen
Committed by
GitHub
Jul 17, 2023
Browse files
[From single file] Make accelerate optional (#4132)
* Make accelerate optional * make accelerate optional
parent
e27500b7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
19 deletions
+31
-19
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
...diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+31
-19
No files found.
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
View file @
5729829c
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
""" Conversion script for the Stable Diffusion checkpoints."""
""" Conversion script for the Stable Diffusion checkpoints."""
import
re
import
re
from
contextlib
import
nullcontext
from
io
import
BytesIO
from
io
import
BytesIO
from
typing
import
Optional
from
typing
import
Optional
...
@@ -779,7 +780,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
...
@@ -779,7 +780,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
config_name
=
"openai/clip-vit-large-patch14"
config_name
=
"openai/clip-vit-large-patch14"
config
=
CLIPTextConfig
.
from_pretrained
(
config_name
)
config
=
CLIPTextConfig
.
from_pretrained
(
config_name
)
with
init_empty_weights
():
ctx
=
init_empty_weights
if
is_accelerate_available
()
else
nullcontext
with
ctx
():
text_model
=
CLIPTextModel
(
config
)
text_model
=
CLIPTextModel
(
config
)
keys
=
list
(
checkpoint
.
keys
())
keys
=
list
(
checkpoint
.
keys
())
...
@@ -793,8 +795,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
...
@@ -793,8 +795,11 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
if
key
.
startswith
(
prefix
):
if
key
.
startswith
(
prefix
):
text_model_dict
[
key
[
len
(
prefix
+
"."
)
:]]
=
checkpoint
[
key
]
text_model_dict
[
key
[
len
(
prefix
+
"."
)
:]]
=
checkpoint
[
key
]
for
param_name
,
param
in
text_model_dict
.
items
():
if
is_accelerate_available
():
set_module_tensor_to_device
(
text_model
,
param_name
,
"cpu"
,
value
=
param
)
for
param_name
,
param
in
text_model_dict
.
items
():
set_module_tensor_to_device
(
text_model
,
param_name
,
"cpu"
,
value
=
param
)
else
:
text_model
.
load_state_dict
(
text_model_dict
)
return
text_model
return
text_model
...
@@ -900,7 +905,8 @@ def convert_open_clip_checkpoint(
...
@@ -900,7 +905,8 @@ def convert_open_clip_checkpoint(
# )
# )
config
=
CLIPTextConfig
.
from_pretrained
(
config_name
,
**
config_kwargs
)
config
=
CLIPTextConfig
.
from_pretrained
(
config_name
,
**
config_kwargs
)
with
init_empty_weights
():
ctx
=
init_empty_weights
if
is_accelerate_available
()
else
nullcontext
with
ctx
():
text_model
=
CLIPTextModelWithProjection
(
config
)
if
has_projection
else
CLIPTextModel
(
config
)
text_model
=
CLIPTextModelWithProjection
(
config
)
if
has_projection
else
CLIPTextModel
(
config
)
keys
=
list
(
checkpoint
.
keys
())
keys
=
list
(
checkpoint
.
keys
())
...
@@ -950,8 +956,11 @@ def convert_open_clip_checkpoint(
...
@@ -950,8 +956,11 @@ def convert_open_clip_checkpoint(
text_model_dict
[
new_key
]
=
checkpoint
[
key
]
text_model_dict
[
new_key
]
=
checkpoint
[
key
]
for
param_name
,
param
in
text_model_dict
.
items
():
if
is_accelerate_available
():
set_module_tensor_to_device
(
text_model
,
param_name
,
"cpu"
,
value
=
param
)
for
param_name
,
param
in
text_model_dict
.
items
():
set_module_tensor_to_device
(
text_model
,
param_name
,
"cpu"
,
value
=
param
)
else
:
text_model
.
load_state_dict
(
text_model_dict
)
return
text_model
return
text_model
...
@@ -1172,11 +1181,6 @@ def download_from_original_stable_diffusion_ckpt(
...
@@ -1172,11 +1181,6 @@ def download_from_original_stable_diffusion_ckpt(
StableUnCLIPPipeline
,
StableUnCLIPPipeline
,
)
)
if
not
is_accelerate_available
():
raise
ImportError
(
"To correctly use `from_single_file`, please make sure that `accelerate` is installed. You can install it with `pip install accelerate`."
)
if
pipeline_class
is
None
:
if
pipeline_class
is
None
:
pipeline_class
=
StableDiffusionPipeline
pipeline_class
=
StableDiffusionPipeline
...
@@ -1346,15 +1350,19 @@ def download_from_original_stable_diffusion_ckpt(
...
@@ -1346,15 +1350,19 @@ def download_from_original_stable_diffusion_ckpt(
# Convert the UNet2DConditionModel model.
# Convert the UNet2DConditionModel model.
unet_config
=
create_unet_diffusers_config
(
original_config
,
image_size
=
image_size
)
unet_config
=
create_unet_diffusers_config
(
original_config
,
image_size
=
image_size
)
unet_config
[
"upcast_attention"
]
=
upcast_attention
unet_config
[
"upcast_attention"
]
=
upcast_attention
with
init_empty_weights
():
unet
=
UNet2DConditionModel
(
**
unet_config
)
converted_unet_checkpoint
=
convert_ldm_unet_checkpoint
(
converted_unet_checkpoint
=
convert_ldm_unet_checkpoint
(
checkpoint
,
unet_config
,
path
=
checkpoint_path
,
extract_ema
=
extract_ema
checkpoint
,
unet_config
,
path
=
checkpoint_path
,
extract_ema
=
extract_ema
)
)
for
param_name
,
param
in
converted_unet_checkpoint
.
items
():
ctx
=
init_empty_weights
if
is_accelerate_available
()
else
nullcontext
set_module_tensor_to_device
(
unet
,
param_name
,
"cpu"
,
value
=
param
)
with
ctx
():
unet
=
UNet2DConditionModel
(
**
unet_config
)
if
is_accelerate_available
():
for
param_name
,
param
in
converted_unet_checkpoint
.
items
():
set_module_tensor_to_device
(
unet
,
param_name
,
"cpu"
,
value
=
param
)
else
:
unet
.
load_state_dict
(
converted_unet_checkpoint
)
# Convert the VAE model.
# Convert the VAE model.
if
vae_path
is
None
:
if
vae_path
is
None
:
...
@@ -1372,11 +1380,15 @@ def download_from_original_stable_diffusion_ckpt(
...
@@ -1372,11 +1380,15 @@ def download_from_original_stable_diffusion_ckpt(
vae_config
[
"scaling_factor"
]
=
vae_scaling_factor
vae_config
[
"scaling_factor"
]
=
vae_scaling_factor
with
init_empty_weights
():
ctx
=
init_empty_weights
if
is_accelerate_available
()
else
nullcontext
with
ctx
():
vae
=
AutoencoderKL
(
**
vae_config
)
vae
=
AutoencoderKL
(
**
vae_config
)
for
param_name
,
param
in
converted_vae_checkpoint
.
items
():
if
is_accelerate_available
():
set_module_tensor_to_device
(
vae
,
param_name
,
"cpu"
,
value
=
param
)
for
param_name
,
param
in
converted_vae_checkpoint
.
items
():
set_module_tensor_to_device
(
vae
,
param_name
,
"cpu"
,
value
=
param
)
else
:
vae
.
load_state_dict
(
converted_vae_checkpoint
)
else
:
else
:
vae
=
AutoencoderKL
.
from_pretrained
(
vae_path
)
vae
=
AutoencoderKL
.
from_pretrained
(
vae_path
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment