Unverified Commit 186689af authored by Víctor Martínez's avatar Víctor Martínez Committed by GitHub
Browse files

fix: un-existing tmp config file in linux, avoid unnecessary disk IO (#2591)

parent cbbad0af
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
# limitations under the License. # limitations under the License.
""" Conversion script for the Stable Diffusion checkpoints.""" """ Conversion script for the Stable Diffusion checkpoints."""
import os
import re import re
import tempfile from io import BytesIO
from typing import Optional from typing import Optional
import requests import requests
...@@ -1046,31 +1045,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt( ...@@ -1046,31 +1045,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
with tempfile.TemporaryDirectory() as tmpdir: if original_config_file is None:
if original_config_file is None: key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
# model_type = "v1"
original_config_file = os.path.join(tmpdir, "inference.yaml") config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
if not os.path.isfile("v2-inference-v.yaml"): if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
# model_type = "v2" # model_type = "v2"
r = requests.get( config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
" https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
) if global_step == 110000:
open(original_config_file, "wb").write(r.content) # v2.1 needs to upcast attention
upcast_attention = True
if global_step == 110000:
# v2.1 needs to upcast attention original_config_file = BytesIO(requests.get(config_url).content)
upcast_attention = True
else: original_config = OmegaConf.load(original_config_file)
if not os.path.isfile("v1-inference.yaml"):
# model_type = "v1"
r = requests.get(
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
open(original_config_file, "wb").write(r.content)
original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None: if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment