"examples/git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "6ce55e4b011275e43404034832b40648b1483ff6"
Unverified Commit ef29b24f authored by Vladimir Mandic's avatar Vladimir Mandic Committed by GitHub
Browse files

allow loading of sd models from safetensors without online lookups using local config files (#5019)

finish config_files implementation
parent 8dc93ad3
...@@ -154,6 +154,7 @@ if __name__ == "__main__": ...@@ -154,6 +154,7 @@ if __name__ == "__main__":
pipe = download_from_original_stable_diffusion_ckpt( pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=args.checkpoint_path, checkpoint_path_or_dict=args.checkpoint_path,
original_config_file=args.original_config_file, original_config_file=args.original_config_file,
config_files=args.config_files,
image_size=args.image_size, image_size=args.image_size,
prediction_type=args.prediction_type, prediction_type=args.prediction_type,
model_type=args.pipeline_type, model_type=args.pipeline_type,
......
...@@ -2099,6 +2099,7 @@ class FromSingleFileMixin: ...@@ -2099,6 +2099,7 @@ class FromSingleFileMixin:
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
original_config_file = kwargs.pop("original_config_file", None) original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
...@@ -2216,6 +2217,7 @@ class FromSingleFileMixin: ...@@ -2216,6 +2217,7 @@ class FromSingleFileMixin:
vae=vae, vae=vae,
tokenizer=tokenizer, tokenizer=tokenizer,
original_config_file=original_config_file, original_config_file=original_config_file,
config_files=config_files,
) )
if torch_dtype is not None: if torch_dtype is not None:
......
...@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
config_url = None
# model_type = "v1" # model_type = "v1"
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" if config_files is not None and "v1" in config_files:
original_config_file = config_files["v1"]
else:
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
# model_type = "v2" # model_type = "v2"
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" if config_files is not None and "v2" in config_files:
original_config_file = config_files["v2"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
if global_step == 110000: if global_step == 110000:
# v2.1 needs to upcast attention # v2.1 needs to upcast attention
upcast_attention = True upcast_attention = True
elif key_name_sd_xl_base in checkpoint: elif key_name_sd_xl_base in checkpoint:
# only base xl has two text embedders # only base xl has two text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" if config_files is not None and "xl" in config_files:
original_config_file = config_files["xl"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
elif key_name_sd_xl_refiner in checkpoint: elif key_name_sd_xl_refiner in checkpoint:
# only refiner xl has embedder and one text embedders # only refiner xl has embedder and one text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" if config_files is not None and "xl_refiner" in config_files:
original_config_file = config_files["xl_refiner"]
original_config_file = BytesIO(requests.get(config_url).content) else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment