renzhc / diffusers_dcu · Commit f21415d1 (unverified)
Authored Dec 02, 2022 by Patrick von Platen; committed by GitHub Dec 02, 2022
Update conversion script to correctly handle SD 2 (#1511)
* Conversion SD 2
* finish
Parent: 22b9cb08
Showing 1 changed file with 115 additions and 35 deletions.

scripts/convert_original_stable_diffusion_to_diffusers.py (+115, -35)
@@ -33,6 +33,7 @@ from diffusers import (
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
 
+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
+        attention_head_dim=head_dim,
+        use_linear_projection=use_linear_projection,
     )
 
     return config
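The hard-coded fallback head_dim = [5, 10, 20, 20] is not arbitrary: the SD 2 UNet config specifies attention via num_head_channels: 64 rather than a num_heads field, so the per-block head count is each block's channel width divided by 64. A quick sanity check (block widths taken from the SD 2 UNet config; this snippet is illustrative, not part of the script):

# Illustrative: derive the SD 2 fallback head counts from block widths.
block_out_channels = [320, 640, 1280, 1280]  # SD 2 UNet channel widths
num_head_channels = 64                       # fixed head size in the SD 2 config
assert [c // num_head_channels for c in block_out_channels] == [5, 10, 20, 20]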
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model
 
 
+def convert_open_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+
+    # SKIP for now - need openclip -> HF conversion script here
+    # keys = list(checkpoint.keys())
+    #
+    # text_model_dict = {}
+    # for key in keys:
+    #     if key.startswith("cond_stage_model.model.transformer"):
+    #         text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
+    #
+    # text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -657,13 +684,22 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument(
         "--image_size",
-        default=512,
+        default=None,
         type=int,
         help=(
             "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
             " Base. Use 768 for Stable Diffusion v2."
         ),
     )
+    parser.add_argument(
+        "--prediction_type",
+        default=None,
+        type=str,
+        help=(
+            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
+            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
@@ -674,10 +710,26 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
     args = parser.parse_args()
 
+    image_size = args.image_size
+    prediction_type = args.prediction_type
+
+    checkpoint = torch.load(args.checkpoint_path)
+    global_step = checkpoint["global_step"]
+    checkpoint = checkpoint["state_dict"]
+
     if args.original_config_file is None:
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            args.original_config_file = "./v2-inference-v.yaml"
+        else:
+            # model_type = "v1"
             os.system(
                 "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
             )
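The auto-detection above keys off the width of a cross-attention projection: SD 1.x conditions on CLIP ViT-L/14 text embeddings (hidden size 768), while SD 2 conditions on OpenCLIP ViT-H (hidden size 1024), so the last dimension of any attn2.to_k weight distinguishes the two. A minimal standalone sketch of the same heuristic (the checkpoint path is hypothetical):

import torch

# Illustrative v1-vs-v2 probe; mirrors the shape check in the script above.
state_dict = torch.load("sd-checkpoint.ckpt", map_location="cpu")["state_dict"]
key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
is_v2 = key in state_dict and state_dict[key].shape[-1] == 1024
print("SD 2 checkpoint" if is_v2 else "SD 1.x checkpoint")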
@@ -685,54 +737,69 @@ if __name__ == "__main__":
     original_config = OmegaConf.load(args.original_config_file)
 
-    checkpoint = torch.load(args.checkpoint_path)
-    checkpoint = checkpoint["state_dict"]
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
+            # as it relies on a brittle global step parameter here
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
+            # as it relies on a brittle global step parameter here
+            image_size = 512 if global_step == 875000 else 768
+    else:
+        if prediction_type is None:
+            prediction_type = "epsilon"
+        if image_size is None:
+            image_size = 512
 
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end
-    if args.scheduler_type == "pndm":
-        scheduler = PNDMScheduler(
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            beta_start=beta_start,
-            num_train_timesteps=num_train_timesteps,
-            skip_prk_steps=True,
-        )
+
+    scheduler = DDIMScheduler(
+        beta_end=beta_end,
+        beta_schedule="scaled_linear",
+        beta_start=beta_start,
+        num_train_timesteps=num_train_timesteps,
+        steps_offset=1,
+        clip_sample=False,
+        set_alpha_to_one=False,
+        prediction_type=prediction_type,
+    )
+
+    if args.scheduler_type == "pndm":
+        config = dict(scheduler.config)
+        config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(config)
     elif args.scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
    elif args.scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
+        scheduler = scheduler
     else:
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
 
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
-    unet = UNet2DConditionModel(**unet_config)
-
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
 
+    unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
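The scheduler refactor replaces per-branch constructors with a single, fully specified DDIMScheduler whose config, now including prediction_type, is reused to instantiate whichever sampler was requested. A minimal sketch of the pattern, with illustrative SD-style beta values (0.00085 / 0.012 are the usual linear_start / linear_end):

from diffusers import DDIMScheduler, EulerDiscreteScheduler

# Define the training-time noise schedule once...
base = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
    prediction_type="v_prediction",  # "epsilon" for SD 1.x and SD 2 base
)
# ...then derive any other sampler from the shared config.
sampler = EulerDiscreteScheduler.from_config(base.config)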
@@ -740,7 +807,20 @@ if __name__ == "__main__":
 
     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenCLIPEmbedder":
+    if text_model_type == "FrozenOpenCLIPEmbedder":
+        text_model = convert_open_clip_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_model,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+        )
+    elif text_model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
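After conversion, the directory written to --dump_path loads like any other diffusers checkpoint. A hypothetical end-to-end use (paths and prompt are placeholders; for a v2 768 model the script infers image_size=768 and prediction_type="v_prediction", or they can be passed explicitly):

import torch
from diffusers import StableDiffusionPipeline

# Hypothetical: load the converted weights and sample an image.
pipe = StableDiffusionPipeline.from_pretrained("./sd2-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")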