Unverified Commit 8bff7823 authored by Patrick von Platen, committed by GitHub

Improve single loading file (#4041)

* start improving single file load

* Fix more

* start improving single file load

* Fix sd 2.1

* further improve from_single_file
parent 66328236
@@ -1389,7 +1389,7 @@ class FromSingleFileMixin:
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         extract_ema = kwargs.pop("extract_ema", False)
-        image_size = kwargs.pop("image_size", 512)
+        image_size = kwargs.pop("image_size", None)
         scheduler_type = kwargs.pop("scheduler_type", "pndm")
         num_in_channels = kwargs.pop("num_in_channels", None)
         upcast_attention = kwargs.pop("upcast_attention", None)
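Changing the `image_size` default from 512 to `None` lets the converter distinguish "not set" from "explicitly 512". A minimal usage sketch of the resulting behaviour (the local file name is a placeholder, not a file from this PR):

```python
from diffusers import StableDiffusionPipeline

# image_size left unset: the converter may pick a model-specific value
# (e.g. 1024 for SDXL checkpoints, as wired up later in this diff).
pipe = StableDiffusionPipeline.from_single_file("v2-1_768-ema-pruned.safetensors")

# Passing image_size explicitly still overrides whatever would be inferred.
pipe = StableDiffusionPipeline.from_single_file("v2-1_768-ema-pruned.safetensors", image_size=768)
```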
@@ -24,6 +24,7 @@ from transformers import (
     AutoFeatureExtractor,
     BertTokenizerFast,
     CLIPImageProcessor,
+    CLIPTextConfig,
     CLIPTextModel,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
@@ -48,7 +49,7 @@ from ...schedulers import (
     PNDMScheduler,
     UnCLIPScheduler,
 )
-from ...utils import is_omegaconf_available, is_safetensors_available, logging
+from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
 from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from ..paint_by_example import PaintByExampleImageEncoder
@@ -57,6 +58,10 @@ from .safety_checker import StableDiffusionSafetyChecker
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
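The guarded `accelerate` imports support the low-memory loading pattern used throughout this file: models are instantiated with empty (meta) weights, and individual tensors are then materialized from the converted state dict. A self-contained sketch of that pattern, using a small `nn.Linear` as a stand-in for the real models:

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build the module without allocating real parameter memory (tensors live on "meta").
with init_empty_weights():
    model = torch.nn.Linear(4, 4)

# Materialize each tensor on CPU from a converted state dict, one parameter at a
# time, instead of calling load_state_dict on a fully allocated module.
state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
for param_name, param in state_dict.items():
    set_module_tensor_to_device(model, param_name, "cpu", value=param)
```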
@@ -770,11 +775,12 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
 
 def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
-    text_model = (
-        CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
-        if text_encoder is None
-        else text_encoder
-    )
+    if text_encoder is None:
+        config_name = "openai/clip-vit-large-patch14"
+        config = CLIPTextConfig.from_pretrained(config_name)
+
+        with init_empty_weights():
+            text_model = CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
@@ -787,7 +793,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
         if key.startswith(prefix):
             text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
 
-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
 
     return text_model
@@ -884,14 +891,26 @@ def convert_paint_by_example_checkpoint(checkpoint):
     return model
 
-def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
+def convert_open_clip_checkpoint(
+    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+):
     # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
-    text_model = CLIPTextModelWithProjection.from_pretrained(
-        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
-    )
+    # text_model = CLIPTextModelWithProjection.from_pretrained(
+    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+    # )
+    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+
+    with init_empty_weights():
+        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
 
     keys = list(checkpoint.keys())
 
+    keys_to_ignore = []
+    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+        # make sure to remove all keys > 22
+        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+        keys_to_ignore += ["cond_stage_model.model.text_projection"]
+
     text_model_dict = {}
 
     if prefix + "text_projection" in checkpoint:
@@ -902,8 +921,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
 
     for key in keys:
-        # if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
-        #     continue
+        if key in keys_to_ignore:
+            continue
         if key[len(prefix) :] in textenc_conversion_map:
             if key.endswith("text_projection"):
                 value = checkpoint[key].T
@@ -931,7 +950,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
             text_model_dict[new_key] = checkpoint[key]
 
-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
 
     return text_model
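For reference, the reworked converter ends up being called in two shapes further down in this diff. The sketch below is illustrative only, assumes the converter lives in `diffusers.pipelines.stable_diffusion.convert_from_ckpt`, and uses a placeholder checkpoint path; in practice each call targets a different checkpoint family.

```python
import torch
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_open_clip_checkpoint

# Placeholder path; original LDM checkpoints keep their weights under "state_dict".
checkpoint = torch.load("v2-1_768-ema-pruned.ckpt", map_location="cpu")["state_dict"]

# SD 2.x text encoder: config comes from the stabilityai repo, no projection head.
text_model = convert_open_clip_checkpoint(checkpoint, "stabilityai/stable-diffusion-2", subfolder="text_encoder")

# SDXL second text encoder (for an SDXL checkpoint): OpenCLIP bigG config plus a 1280-dim projection.
text_encoder_2 = convert_open_clip_checkpoint(
    checkpoint,
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    prefix="conditioner.embedders.1.model.",
    has_projection=True,
    projection_dim=1280,
)
```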
@@ -1061,7 +1081,7 @@ def convert_controlnet_checkpoint(
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
-    image_size: int = 512,
+    image_size: Optional[int] = None,
     prediction_type: str = None,
     model_type: str = None,
     extract_ema: bool = False,
@@ -1144,6 +1164,7 @@ def download_from_original_stable_diffusion_ckpt(
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         StableDiffusionControlNetPipeline,
+        StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLPipeline,
@@ -1166,12 +1187,9 @@ def download_from_original_stable_diffusion_ckpt(
         if not is_safetensors_available():
             raise ValueError(BACKENDS_MAPPING["safetensors"][1])
 
-        from safetensors import safe_open
+        from safetensors.torch import load_file as safe_load
 
-        checkpoint = {}
-        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                checkpoint[key] = f.get_tensor(key)
+        checkpoint = safe_load(checkpoint_path, device="cpu")
     else:
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
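The replacement relies on `safetensors.torch.load_file`, which returns the full state dict in one call, exactly what the removed `safe_open`/`get_tensor` loop rebuilt by hand. A standalone sketch (placeholder path):

```python
from safetensors.torch import load_file as safe_load

# Equivalent to opening the file and copying every tensor out key by key.
checkpoint = safe_load("v2-1_768-ema-pruned.safetensors", device="cpu")
print(len(checkpoint), "tensors loaded")
```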
@@ -1183,7 +1201,7 @@ def download_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
-        logger.warning("global_step key not found in model")
+        logger.debug("global_step key not found in model")
         global_step = None
 
     # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
...@@ -1230,9 +1248,15 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1230,9 +1248,15 @@ def download_from_original_stable_diffusion_ckpt(
model_type = "SDXL" model_type = "SDXL"
else: else:
model_type = "SDXL-Refiner" model_type = "SDXL-Refiner"
if image_size is None:
image_size = 1024
if num_in_channels is not None: if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels num_in_channels = 9
elif num_in_channels is None:
num_in_channels = 4
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if ( if (
"parameterization" in original_config["model"]["params"] "parameterization" in original_config["model"]["params"]
...@@ -1263,7 +1287,6 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1263,7 +1287,6 @@ def download_from_original_stable_diffusion_ckpt(
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
if model_type in ["SDXL", "SDXL-Refiner"]: if model_type in ["SDXL", "SDXL-Refiner"]:
image_size = 1024
scheduler_dict = { scheduler_dict = {
"beta_schedule": "scaled_linear", "beta_schedule": "scaled_linear",
"beta_start": 0.00085, "beta_start": 0.00085,
...@@ -1279,7 +1302,6 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1279,7 +1302,6 @@ def download_from_original_stable_diffusion_ckpt(
} }
scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
scheduler_type = "euler" scheduler_type = "euler"
vae_path = "stabilityai/sdxl-vae"
else: else:
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
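From the caller's side, the `in_channels` defaults introduced above mean an inpainting pipeline automatically gets a 9-channel UNet while everything else falls back to 4, and `num_in_channels` remains an explicit override. A hedged sketch; the converter's module path, its `pipeline_class` argument, and the checkpoint file names are assumptions, not part of this diff:

```python
from diffusers import StableDiffusionInpaintPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Placeholder local path; an inpainting pipeline class now implies in_channels=9.
pipe = download_from_original_stable_diffusion_ckpt(
    "sd-v1-5-inpainting.ckpt",
    pipeline_class=StableDiffusionInpaintPipeline,  # assumed keyword, matching the variable used above
)

# num_in_channels stays available for checkpoints with non-standard UNet inputs.
pipe = StableDiffusionInpaintPipeline.from_single_file("custom-inpaint.ckpt", num_in_channels=9)
```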
@@ -1318,25 +1340,45 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
-    unet = UNet2DConditionModel(**unet_config)
+
+    with init_empty_weights():
+        unet = UNet2DConditionModel(**unet_config)
 
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
     )
-    unet.load_state_dict(converted_unet_checkpoint)
+
+    for param_name, param in converted_unet_checkpoint.items():
+        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
 
     # Convert the VAE model.
     if vae_path is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+
+        if (
+            "model" in original_config
+            and "params" in original_config.model
+            and "scale_factor" in original_config.model.params
+        ):
+            vae_scaling_factor = original_config.model.params.scale_factor
+        else:
+            vae_scaling_factor = 0.18215  # default SD scaling factor
+
+        vae_config["scaling_factor"] = vae_scaling_factor
+
+        with init_empty_weights():
+            vae = AutoencoderKL(**vae_config)
+
+        for param_name, param in converted_vae_checkpoint.items():
+            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)
 
     if model_type == "FrozenOpenCLIPEmbedder":
-        text_model = convert_open_clip_checkpoint(checkpoint)
+        config_name = "stabilityai/stable-diffusion-2"
+        config_kwargs = {"subfolder": "text_encoder"}
+        text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
+
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
 
         if stable_unclip is None:
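The scaling-factor fallback above can be illustrated in isolation. A rough sketch assuming an OmegaConf-style original config; 0.13025 is the value SDXL VAE configs carry, while 0.18215 is the historical Stable Diffusion default:

```python
from omegaconf import OmegaConf

original_config = OmegaConf.create({"model": {"params": {"scale_factor": 0.13025}}})

if (
    "model" in original_config
    and "params" in original_config.model
    and "scale_factor" in original_config.model.params
):
    vae_scaling_factor = original_config.model.params.scale_factor
else:
    vae_scaling_factor = 0.18215  # historical Stable Diffusion default

print(vae_scaling_factor)  # 0.13025, forwarded to AutoencoderKL as scaling_factor
```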
@@ -1469,7 +1511,12 @@ def download_from_original_stable_diffusion_ckpt(
             tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
             text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
             tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+
+            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+            config_kwargs = {"projection_dim": 1280}
+            text_encoder_2 = convert_open_clip_checkpoint(
+                checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
+            )
 
             pipe = StableDiffusionXLPipeline(
                 vae=vae,
@@ -1485,7 +1532,12 @@ def download_from_original_stable_diffusion_ckpt(
             tokenizer = None
             text_encoder = None
             tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
+            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+            config_kwargs = {"projection_dim": 1280}
+            text_encoder_2 = convert_open_clip_checkpoint(
+                checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+            )
 
             pipe = StableDiffusionXLImg2ImgPipeline(
                 vae=vae,
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -153,7 +153,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image
 
-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion.
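With `FromSingleFileMixin` in the base classes, the inpainting pipeline can now be created straight from an original-format checkpoint. A short sketch mirroring the new integration test below:

```python
import torch
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt",
    torch_dtype=torch.float16,
)
pipe.to("cuda")
```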
@@ -20,17 +20,20 @@ import unittest
 import numpy as np
 import torch
+from huggingface_hub import hf_hub_download
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
     AutoencoderKL,
+    DDIMScheduler,
     DPMSolverMultistepScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionInpaintPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.attention_processor import AttnProcessor
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import (
@@ -512,6 +515,42 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
         assert np.abs(expected_slice - image_slice).max() < 6e-4
 
+    def test_download_local(self):
+        filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 1
+        image_out = pipe(**inputs).images[0]
+
+        assert image_out.shape == (512, 512, 3)
+
+    def test_download_ckpt_diff_format_is_same(self):
+        ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image_ckpt = pipe(**inputs).images[0]
+
+        pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image = pipe(**inputs).images[0]
+
+        assert np.max(np.abs(image - image_ckpt)) < 1e-4
+
 
 @nightly
 @require_torch_gpu
@@ -19,6 +19,7 @@ import unittest
 import numpy as np
 import torch
+from huggingface_hub import hf_hub_download
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
@@ -29,6 +30,7 @@ from diffusers import (
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
@@ -426,6 +428,40 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
         assert image.shape == (768, 768, 3)
         assert np.abs(expected_image - image).max() < 7.5e-1
 
+    def test_download_local(self):
+        filename = hf_hub_download("stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.safetensors")
+
+        pipe = StableDiffusionPipeline.from_single_file(filename, torch_dtype=torch.float16)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]
+
+        assert image_out.shape == (768, 768, 3)
+
+    def test_download_ckpt_diff_format_is_same(self):
+        single_file_path = (
+            "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
+        )
+
+        pipe_single = StableDiffusionPipeline.from_single_file(single_file_path)
+        pipe_single.scheduler = DDIMScheduler.from_config(pipe_single.scheduler.config)
+        pipe_single.unet.set_attn_processor(AttnProcessor())
+        pipe_single.to("cuda")
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        image_ckpt = pipe_single("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]
+
+        pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        image = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]
+
+        assert np.max(np.abs(image - image_ckpt)) < 1e-3
+
     def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
         number_of_steps = 0