Unverified Commit 186689af authored by Víctor Martínez's avatar Víctor Martínez Committed by GitHub
Browse files

fix: un-existing tmp config file in linux, avoid unnecessary disk IO (#2591)

parent cbbad0af
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
# limitations under the License. # limitations under the License.
""" Conversion script for the Stable Diffusion checkpoints.""" """ Conversion script for the Stable Diffusion checkpoints."""
import os
import re import re
import tempfile from io import BytesIO
from typing import Optional from typing import Optional
import requests import requests
...@@ -1046,29 +1045,21 @@ def load_pipeline_from_original_stable_diffusion_ckpt( ...@@ -1046,29 +1045,21 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
with tempfile.TemporaryDirectory() as tmpdir:
if original_config_file is None: if original_config_file is None:
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
original_config_file = os.path.join(tmpdir, "inference.yaml") # model_type = "v1"
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
if not os.path.isfile("v2-inference-v.yaml"):
# model_type = "v2" # model_type = "v2"
r = requests.get( config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
" https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
)
open(original_config_file, "wb").write(r.content)
if global_step == 110000: if global_step == 110000:
# v2.1 needs to upcast attention # v2.1 needs to upcast attention
upcast_attention = True upcast_attention = True
else:
if not os.path.isfile("v1-inference.yaml"): original_config_file = BytesIO(requests.get(config_url).content)
# model_type = "v1"
r = requests.get(
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
open(original_config_file, "wb").write(r.content)
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment