Unverified commit f21415d1 authored by Patrick von Platen, committed by GitHub

Update conversion script to correctly handle SD 2 (#1511)

* Conversion SD 2

* finish
parent 22b9cb08
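In short, the script now auto-detects whether a checkpoint is a Stable Diffusion 1.x or a Stable Diffusion 2 model by inspecting the width of a UNet cross-attention key weight (1024 for the OpenCLIP text encoder used by SD 2), and it lets `--image_size` and `--prediction_type` be inferred when they are not passed explicitly. A minimal standalone sketch of that detection check, assuming a local checkpoint file whose path is purely illustrative:

```python
import torch

# Load the raw LDM checkpoint; "sd-checkpoint.ckpt" is an illustrative path.
ckpt = torch.load("sd-checkpoint.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# SD 2 uses an OpenCLIP text encoder with 1024-dim context, which shows up in the
# shape of the UNet's cross-attention key projection weights.
key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
is_v2 = key in state_dict and state_dict[key].shape[-1] == 1024
print("SD 2 checkpoint" if is_v2 else "SD 1.x checkpoint")
```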
@@ -33,6 +33,7 @@ from diffusers import (
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
+        attention_head_dim=head_dim,
+        use_linear_projection=use_linear_projection,
     )

     return config
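For a stock 768-px SD 2 checkpoint, the logic in the hunk above ends up with a UNet config roughly along these lines; the values below are an illustrative sketch, not read from a real YAML config:

```python
# Illustrative excerpt of the resulting diffusers UNet config for SD 2 (768 px);
# the full dict also carries in/out channels, block types, etc.
unet_config = dict(
    sample_size=96,                      # 768 // 8 (the VAE scale factor)
    cross_attention_dim=1024,            # OpenCLIP text embedding width
    attention_head_dim=[5, 10, 20, 20],  # per-block head dims for SD 2
    use_linear_projection=True,          # linear (rather than conv) transformer projections
)
```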
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model


+def convert_open_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+
+    # SKIP for now - need openclip -> HF conversion script here
+    # keys = list(checkpoint.keys())
+    #
+    # text_model_dict = {}
+    # for key in keys:
+    #     if key.startswith("cond_stage_model.model.transformer"):
+    #         text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
+    #
+    # text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -657,13 +684,22 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--image_size",
-        default=512,
+        default=None,
         type=int,
         help=(
            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
            " Base. Use 768 for Stable Diffusion v2."
        ),
    )
+    parser.add_argument(
+        "--prediction_type",
+        default=None,
+        type=str,
+        help=(
+            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
+            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
@@ -674,65 +710,96 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()

+    image_size = args.image_size
+    prediction_type = args.prediction_type
+
+    checkpoint = torch.load(args.checkpoint_path)
+    global_step = checkpoint["global_step"]
+    checkpoint = checkpoint["state_dict"]
+
     if args.original_config_file is None:
-        os.system(
-            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-        )
-        args.original_config_file = "./v1-inference.yaml"
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            args.original_config_file = "./v2-inference-v.yaml"
+        else:
+            # model_type = "v1"
+            os.system(
+                "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+            )
+            args.original_config_file = "./v1-inference.yaml"

     original_config = OmegaConf.load(args.original_config_file)

-    checkpoint = torch.load(args.checkpoint_path)
-    checkpoint = checkpoint["state_dict"]
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: for Stable Diffusion 2 Base it is recommended to pass `--prediction_type "epsilon"`
+            # explicitly, as the default below relies on a brittle global-step heuristic
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: for Stable Diffusion 2 Base one has to pass `--image_size 512`
+            # explicitly, as the default below relies on a brittle global-step heuristic
+            image_size = 512 if global_step == 875000 else 768
+    else:
+        if prediction_type is None:
+            prediction_type = "epsilon"
+        if image_size is None:
+            image_size = 512
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end

+    scheduler = DDIMScheduler(
+        beta_end=beta_end,
+        beta_schedule="scaled_linear",
+        beta_start=beta_start,
+        num_train_timesteps=num_train_timesteps,
+        steps_offset=1,
+        clip_sample=False,
+        set_alpha_to_one=False,
+        prediction_type=prediction_type,
+    )
+
     if args.scheduler_type == "pndm":
-        scheduler = PNDMScheduler(
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            beta_start=beta_start,
-            num_train_timesteps=num_train_timesteps,
-            skip_prk_steps=True,
-        )
+        config = dict(scheduler.config)
+        config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(config)
     elif args.scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
+        scheduler = scheduler
     else:
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
-    unet = UNet2DConditionModel(**unet_config)
-
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
+
+    unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)

     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
     vae = AutoencoderKL(**vae_config)
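The scheduler refactor in the hunk above first builds one DDIM scheduler from the LDM config and then re-instantiates whichever scheduler type was requested from that scheduler's config dict, so all of them share the same betas, `steps_offset`, and `prediction_type`. A standalone sketch of that pattern, with illustrative hyperparameter values rather than ones read from a YAML file:

```python
from diffusers import DDIMScheduler, EulerDiscreteScheduler

# Base scheduler carrying the shared hyperparameters (values here are illustrative;
# the script reads them from the original LDM YAML config).
base = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
    prediction_type="v_prediction",
)

# Any other scheduler can now be created with the same hyperparameters; config keys
# the target class does not accept are effectively ignored by `from_config`.
euler = EulerDiscreteScheduler.from_config(base.config)
print(euler.config.prediction_type)  # "v_prediction"
```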
@@ -740,7 +807,20 @@ if __name__ == "__main__":

     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenCLIPEmbedder":
+    if text_model_type == "FrozenOpenCLIPEmbedder":
+        text_model = convert_open_clip_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_model,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+        )
+    elif text_model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
...
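After the (truncated) remainder of the script saves the assembled pipeline to `--dump_path`, the converted SD 2 model can presumably be loaded like any other diffusers pipeline; the directory name below is illustrative:

```python
from diffusers import StableDiffusionPipeline

# "./sd2-diffusers" stands in for whatever --dump_path was used during conversion.
pipe = StableDiffusionPipeline.from_pretrained("./sd2-diffusers")

image = pipe("a photograph of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")
```

Note that at this point in the commit the OpenCLIP weights are not yet copied from the checkpoint (`convert_open_clip_checkpoint` is marked "SKIP for now"), so the text encoder comes from the published stabilityai/stable-diffusion-2 repository rather than from the converted file.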