Unverified Commit ef29b24f authored by Vladimir Mandic, committed by GitHub
Browse files

allow loading of sd models from safetensors without online lookups using local config files (#5019)

finish config_files implementation
parent 8dc93ad3
...@@ -154,6 +154,7 @@ if __name__ == "__main__": ...@@ -154,6 +154,7 @@ if __name__ == "__main__":
pipe = download_from_original_stable_diffusion_ckpt( pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=args.checkpoint_path, checkpoint_path_or_dict=args.checkpoint_path,
original_config_file=args.original_config_file, original_config_file=args.original_config_file,
config_files=args.config_files,
image_size=args.image_size, image_size=args.image_size,
prediction_type=args.prediction_type, prediction_type=args.prediction_type,
model_type=args.pipeline_type, model_type=args.pipeline_type,
......
...@@ -2099,6 +2099,7 @@ class FromSingleFileMixin: ...@@ -2099,6 +2099,7 @@ class FromSingleFileMixin:
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
original_config_file = kwargs.pop("original_config_file", None) original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
...@@ -2216,6 +2217,7 @@ class FromSingleFileMixin: ...@@ -2216,6 +2217,7 @@ class FromSingleFileMixin:
vae=vae, vae=vae,
tokenizer=tokenizer, tokenizer=tokenizer,
original_config_file=original_config_file, original_config_file=original_config_file,
config_files=config_files,
) )
if torch_dtype is not None: if torch_dtype is not None:
......
...@@ -1256,24 +1256,36 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1256,24 +1256,36 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
config_url = None
# model_type = "v1" # model_type = "v1"
if config_files is not None and "v1" in config_files:
original_config_file = config_files["v1"]
else:
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
# model_type = "v2" # model_type = "v2"
if config_files is not None and "v2" in config_files:
original_config_file = config_files["v2"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
if global_step == 110000: if global_step == 110000:
# v2.1 needs to upcast attention # v2.1 needs to upcast attention
upcast_attention = True upcast_attention = True
elif key_name_sd_xl_base in checkpoint: elif key_name_sd_xl_base in checkpoint:
# only base xl has two text embedders # only base xl has two text embedders
if config_files is not None and "xl" in config_files:
original_config_file = config_files["xl"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
elif key_name_sd_xl_refiner in checkpoint: elif key_name_sd_xl_refiner in checkpoint:
# only refiner xl has embedder and one text embedders # only refiner xl has embedder and one text embedders
if config_files is not None and "xl_refiner" in config_files:
original_config_file = config_files["xl_refiner"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content) original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment