Unverified commit 872ae1dd authored by Dhruv Nair, committed by GitHub

Add from single file to StableDiffusionUpscalePipeline and StableDiffusionLatentUpscalePipeline (#5194)

* add from single file

* clean up

* make style

* add single file loading for upscaling
parent 6ce01bd6
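
With this change, both upscaling pipelines can be constructed directly from an original-layout checkpoint via `from_single_file`. A minimal usage sketch (the checkpoint URL is the one exercised by the new integration test below; network access is assumed):

```python
from diffusers import StableDiffusionUpscalePipeline

# Build the x4 upscaler from a single original-format checkpoint file.
# The conversion path added in this commit fetches the x4-upscaling.yaml
# config from the Stability-AI GitHub repo and the DDIM/DDPM schedulers
# from the stabilityai/stable-diffusion-x4-upscaler Hub repo.
pipe = StableDiffusionUpscalePipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
)
```

`StableDiffusionLatentUpscalePipeline` gains the same `FromSingleFileMixin`, so it loads analogously.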
@@ -304,8 +304,6 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False
             class_embed_type = "projection"
             assert "adm_in_channels" in unet_params
             projection_class_embeddings_input_dim = unet_params.adm_in_channels
-        else:
-            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
 
     config = {
         "sample_size": image_size // vae_scale_factor,
@@ -323,6 +321,12 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False
         "transformer_layers_per_block": transformer_layers_per_block,
     }
 
+    if "disable_self_attentions" in unet_params:
+        config["only_cross_attention"] = unet_params.disable_self_attentions
+
+    if "num_classes" in unet_params and type(unet_params.num_classes) == int:
+        config["num_class_embeds"] = unet_params.num_classes
+
     if controlnet:
         config["conditioning_channels"] = unet_params.hint_channels
     else:
@@ -441,6 +445,10 @@ def convert_ldm_unet_checkpoint(
         new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
         new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
 
+    # Relevant to StableDiffusionUpscalePipeline
+    if "num_class_embeds" in config:
+        new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
+
     new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
     new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
@@ -496,6 +504,7 @@ def convert_ldm_unet_checkpoint(
         if len(attentions):
             paths = renew_attention_paths(attentions)
             meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+
             assign_to_checkpoint(
                 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
@@ -1210,6 +1219,7 @@ def download_from_original_stable_diffusion_ckpt(
         StableDiffusionControlNetPipeline,
         StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
+        StableDiffusionUpscalePipeline,
         StableDiffusionXLImg2ImgPipeline,
         StableUnCLIPImg2ImgPipeline,
         StableUnCLIPPipeline,
@@ -1256,6 +1266,8 @@ def download_from_original_stable_diffusion_ckpt(
     key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
     key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
     key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+    is_upscale = pipeline_class == StableDiffusionUpscalePipeline
+
     config_url = None
     # model_type = "v1"
@@ -1285,6 +1297,10 @@ def download_from_original_stable_diffusion_ckpt(
                 original_config_file = config_files["xl_refiner"]
             else:
                 config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
 
+        if is_upscale:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
+
         if config_url is not None:
             original_config_file = BytesIO(requests.get(config_url).content)
@@ -1308,6 +1324,8 @@ def download_from_original_stable_diffusion_ckpt(
     if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
         num_in_channels = 9
+    if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
+        num_in_channels = 7
     elif num_in_channels is None:
         num_in_channels = 4
@@ -1391,9 +1409,13 @@ def download_from_original_stable_diffusion_ckpt(
     else:
         raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
 
+    if pipeline_class == StableDiffusionUpscalePipeline:
+        image_size = original_config.model.params.unet_config.params.image_size
+
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
+
     path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=path, extract_ema=extract_ema
@@ -1458,8 +1480,29 @@ def download_from_original_stable_diffusion_ckpt(
                 controlnet=controlnet,
                 safety_checker=None,
                 feature_extractor=None,
-                requires_safety_checker=False,
             )
+            if hasattr(pipe, "requires_safety_checker"):
+                pipe.requires_safety_checker = False
+
+        elif pipeline_class == StableDiffusionUpscalePipeline:
+            scheduler = DDIMScheduler.from_pretrained(
+                "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
+            )
+            low_res_scheduler = DDPMScheduler.from_pretrained(
+                "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
+            )
+
+            pipe = pipeline_class(
+                vae=vae,
+                text_encoder=text_model,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                low_res_scheduler=low_res_scheduler,
+                safety_checker=None,
+                feature_extractor=None,
+            )
+
         else:
             pipe = pipeline_class(
                 vae=vae,
@@ -1469,8 +1512,10 @@ def download_from_original_stable_diffusion_ckpt(
                 scheduler=scheduler,
                 safety_checker=None,
                 feature_extractor=None,
-                requires_safety_checker=False,
             )
+            if hasattr(pipe, "requires_safety_checker"):
+                pipe.requires_safety_checker = False
+
     else:
         image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
             original_config, clip_stats_path=clip_stats_path, device=device
...
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import EulerDiscreteScheduler
 from ...utils import deprecate, logging
@@ -59,7 +60,7 @@ def preprocess(image):
     return image
 
 
-class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
+class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixin):
     r"""
     Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2.
...
@@ -22,7 +22,7 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
@@ -67,7 +67,9 @@ def preprocess(image):
     return image
 
 
-class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionUpscalePipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image super-resolution using Stable Diffusion 2.
...
@@ -29,6 +29,7 @@ from diffusers.utils.testing_utils import (
     floats_tensor,
     load_image,
     load_numpy,
+    numpy_cosine_similarity_distance,
     require_torch_gpu,
     slow,
     torch_device,
@@ -479,3 +480,36 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
         mem_bytes = torch.cuda.max_memory_allocated()
         # make sure that less than 2.9 GB is allocated
         assert mem_bytes < 2.9 * 10**9
+
+    def test_download_ckpt_diff_format_is_same(self):
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/sd2-upscale/low_res_cat.png"
+        )
+        prompt = "a cat sitting on a park bench"
+        model_id = "stabilityai/stable-diffusion-x4-upscaler"
+
+        pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
+        pipe.enable_model_cpu_offload()
+
+        generator = torch.Generator("cpu").manual_seed(0)
+        output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
+        image_from_pretrained = output.images[0]
+
+        single_file_path = (
+            "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
+        )
+        pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(single_file_path)
+        pipe_from_single_file.enable_model_cpu_offload()
+
+        generator = torch.Generator("cpu").manual_seed(0)
+        output_from_single_file = pipe_from_single_file(
+            prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
+        )
+        image_from_single_file = output_from_single_file.images[0]
+
+        assert image_from_pretrained.shape == (512, 512, 3)
+        assert image_from_single_file.shape == (512, 512, 3)
+        assert (
+            numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
+        )