Unverified Commit 872ae1dd authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Add from single file to StableDiffusionUpscalePipeline and...

Add from single file to StableDiffusionUpscalePipeline and StableDiffusionLatentUpscalePipeline (#5194)

* add from single file

* clean up

* make style

* add single file loading for upscaling
parent 6ce01bd6
......@@ -304,8 +304,6 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
else:
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
config = {
"sample_size": image_size // vae_scale_factor,
......@@ -323,6 +321,12 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
"transformer_layers_per_block": transformer_layers_per_block,
}
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
if "num_classes" in unet_params and type(unet_params.num_classes) == int:
config["num_class_embeds"] = unet_params.num_classes
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
......@@ -441,6 +445,10 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
# Relevant to StableDiffusionUpscalePipeline
if "num_class_embeds" in config:
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
......@@ -496,6 +504,7 @@ def convert_ldm_unet_checkpoint(
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
......@@ -1210,6 +1219,7 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLImg2ImgPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
......@@ -1256,6 +1266,8 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
is_upscale = pipeline_class == StableDiffusionUpscalePipeline
config_url = None
# model_type = "v1"
......@@ -1285,6 +1297,10 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = config_files["xl_refiner"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
if is_upscale:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
......@@ -1308,6 +1324,8 @@ def download_from_original_stable_diffusion_ckpt(
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
num_in_channels = 7
elif num_in_channels is None:
num_in_channels = 4
......@@ -1391,9 +1409,13 @@ def download_from_original_stable_diffusion_ckpt(
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
if pipeline_class == StableDiffusionUpscalePipeline:
image_size = original_config.model.params.unet_config.params.image_size
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=path, extract_ema=extract_ema
......@@ -1458,8 +1480,29 @@ def download_from_original_stable_diffusion_ckpt(
controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False
elif pipeline_class == StableDiffusionUpscalePipeline:
scheduler = DDIMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
)
low_res_scheduler = DDPMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
)
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
low_res_scheduler=low_res_scheduler,
safety_checker=None,
feature_extractor=None,
)
else:
pipe = pipeline_class(
vae=vae,
......@@ -1469,8 +1512,10 @@ def download_from_original_stable_diffusion_ckpt(
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False
else:
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config, clip_stats_path=clip_stats_path, device=device
......
......@@ -22,6 +22,7 @@ import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import deprecate, logging
......@@ -59,7 +60,7 @@ def preprocess(image):
return image
class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2.
......
......@@ -22,7 +22,7 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
......@@ -67,7 +67,9 @@ def preprocess(image):
return image
class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
class StableDiffusionUpscalePipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
......
......@@ -29,6 +29,7 @@ from diffusers.utils.testing_utils import (
floats_tensor,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
......@@ -479,3 +480,36 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.9 GB is allocated
assert mem_bytes < 2.9 * 10**9
def test_download_ckpt_diff_format_is_same(self):
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/sd2-upscale/low_res_cat.png"
)
prompt = "a cat sitting on a park bench"
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
pipe.enable_model_cpu_offload()
generator = torch.Generator("cpu").manual_seed(0)
output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
image_from_pretrained = output.images[0]
single_file_path = (
"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
)
pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(single_file_path)
pipe_from_single_file.enable_model_cpu_offload()
generator = torch.Generator("cpu").manual_seed(0)
output_from_single_file = pipe_from_single_file(
prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
)
image_from_single_file = output_from_single_file.images[0]
assert image_from_pretrained.shape == (512, 512, 3)
assert image_from_single_file.shape == (512, 512, 3)
assert (
numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment