Unverified Commit 6b1abba1 authored by Patrick von Platen, committed by GitHub

Add controlnet and vae from single file (#4084)



* Add controlnet from single file

* Updates

* make style

* finish

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 470f51cd
...@@ -35,3 +35,11 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio
## FromSingleFileMixin
[[autodoc]] loaders.FromSingleFileMixin
## FromOriginalControlnetMixin
[[autodoc]] loaders.FromOriginalControlnetMixin
## FromOriginalVAEMixin
[[autodoc]] loaders.FromOriginalVAEMixin
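Both new mixins add a `from_single_file` classmethod to their host model classes. As a quick orientation, a minimal sketch of what this enables (the example checkpoints below are the same ones used in the model docs of this change):
```py
from diffusers import AutoencoderKL, ControlNetModel

# `from_single_file` accepts a Hub file URL or a local checkpoint path.
vae = AutoencoderKL.from_single_file(
    "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
)
controlnet = ControlNetModel.from_single_file(
    "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
```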
...@@ -6,6 +6,18 @@ The abstract from the paper is:
*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
## Loading from the original format
By default, [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
```py
from diffusers import AutoencoderKL
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
model = AutoencoderKL.from_single_file(url)
```
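A VAE loaded this way is a regular [`AutoencoderKL`] instance, so it can be swapped into a pipeline like any other component. A minimal sketch (the pipeline repo id is only an illustrative choice):
```py
from diffusers import AutoencoderKL, StableDiffusionPipeline

url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
vae = AutoencoderKL.from_single_file(url)

# Override the pipeline's default VAE with the one loaded from the original checkpoint.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae)
```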
## AutoencoderKL
[[autodoc]] AutoencoderKL
...@@ -28,4 +40,4 @@ The abstract from the paper is:
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
\ No newline at end of file
...@@ -6,6 +6,21 @@ The abstract from the paper is:
*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*
## Loading from the original format
By default, [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
controlnet = ControlNetModel.from_single_file(url)
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
```
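The resulting pipeline behaves like one created with [`~DiffusionPipeline.from_pretrained`]. Continuing the snippet above, a minimal sketch of running it on a canny edge map (the conditioning image and sampling settings are only illustrative):
```py
import torch
from diffusers.utils import load_image

# Canny conditioning image; any edge map of the right resolution works.
canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
generator = torch.Generator("cpu").manual_seed(0)
image = pipe("bird", image=canny_image, generator=generator, num_inference_steps=20).images[0]
```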
## ControlNetModel
[[autodoc]] ControlNetModel
...@@ -20,4 +35,4 @@ The abstract from the paper is:
## FlaxControlNetOutput
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
\ No newline at end of file
...@@ -18,6 +18,7 @@ import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput, apply_forward_hook
from .attention_processor import AttentionProcessor, AttnProcessor
from .modeling_utils import ModelMixin
...@@ -38,7 +39,7 @@ class AutoencoderKLOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution"
-class AutoencoderKL(ModelMixin, ConfigMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r""" r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
......
...@@ -19,6 +19,7 @@ from torch import nn
from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlnetMixin
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
...@@ -100,7 +101,7 @@ class ControlNetConditioningEmbedding(nn.Module):
return embedding
-class ControlNetModel(ModelMixin, ConfigMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
""" """
A ControlNet model. A ControlNet model.
......
...@@ -24,7 +24,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
...@@ -90,7 +90,9 @@ EXAMPLE_DOC_STRING = """
"""
-class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
......
...@@ -24,7 +24,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
...@@ -116,7 +116,9 @@ def prepare_image(image):
return image
-class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
......
...@@ -25,7 +25,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
...@@ -222,7 +222,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
return mask, masked_image
-class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionControlNetInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
......
...@@ -621,8 +621,8 @@ def convert_ldm_unet_checkpoint(
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
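As I read this hunk, the point of the change is that standalone VAE checkpoints, whose keys are not nested under the `first_stage_model.` prefix used in full Stable Diffusion checkpoints, can now be converted as well. A minimal self-contained sketch of that prefix handling (function and key names are illustrative only):
```py
def extract_vae_state_dict(checkpoint):
    keys = list(checkpoint.keys())
    # Strip the prefix only when it is present; an empty prefix keeps every key unchanged.
    vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
    return {k[len(vae_key):]: checkpoint[k] for k in keys if k.startswith(vae_key)}

# Full SD checkpoint: VAE weights live under "first_stage_model."
full_ckpt = {"first_stage_model.encoder.conv_in.weight": 1, "model.diffusion_model.out.weight": 2}
# Standalone VAE checkpoint (e.g. vae-ft-mse-840000-ema-pruned.safetensors): no prefix
standalone_ckpt = {"encoder.conv_in.weight": 1}

assert extract_vae_state_dict(full_ckpt) == {"encoder.conv_in.weight": 1}
assert extract_vae_state_dict(standalone_ckpt) == {"encoder.conv_in.weight": 1}
```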
...@@ -1064,7 +1064,7 @@ def convert_controlnet_checkpoint(
if cross_attention_dim is not None:
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
-controlnet_model = ControlNetModel(**ctrlnet_config)
+controlnet = ControlNetModel(**ctrlnet_config)
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
...@@ -1082,9 +1082,9 @@ def convert_controlnet_checkpoint(
skip_extract_state_dict=skip_extract_state_dict,
)
-controlnet_model.load_state_dict(converted_ctrl_checkpoint)
-return controlnet_model
+controlnet.load_state_dict(converted_ctrl_checkpoint)
+return controlnet
def download_from_original_stable_diffusion_ckpt(
...@@ -1181,7 +1181,7 @@ def download_from_original_stable_diffusion_ckpt(
)
if pipeline_class is None:
-pipeline_class = StableDiffusionPipeline
+pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
if prediction_type == "v-prediction":
prediction_type = "v_prediction"
...@@ -1288,8 +1288,7 @@ def download_from_original_stable_diffusion_ckpt(
if controlnet is None:
controlnet = "control_stage_config" in original_config.model.params
-if controlnet:
-    controlnet_model = convert_controlnet_checkpoint(
+controlnet = convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
)
...@@ -1400,13 +1399,13 @@ def download_from_original_stable_diffusion_ckpt(
if stable_unclip is None:
if controlnet:
-pipe = StableDiffusionControlNetPipeline(
+pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
-controlnet=controlnet_model,
+controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
...@@ -1503,12 +1502,12 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor = None
if controlnet:
-pipe = StableDiffusionControlNetPipeline(
+pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
-controlnet=controlnet_model,
+controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
...@@ -1623,7 +1622,7 @@ def download_controlnet_from_original_ckpt(
if "control_stage_config" not in original_config.model.params:
raise ValueError("`control_stage_config` not present in original config")
-controlnet_model = convert_controlnet_checkpoint(
+controlnet = convert_controlnet_checkpoint(
checkpoint,
original_config,
checkpoint_path,
...@@ -1634,4 +1633,4 @@ def download_controlnet_from_original_ckpt(
cross_attention_dim=cross_attention_dim,
)
-return controlnet_model
+return controlnet
...@@ -199,7 +199,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
torch_dtype=torch_dtype,
revision=revision,
)
-model.to(torch_device).eval()
+model.to(torch_device)
return model
...@@ -383,3 +383,22 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
def test_stable_diffusion_model_local(self):
model_id = "stabilityai/sd-vae-ft-mse"
model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
image = self.get_sd_image(33)
with torch.no_grad():
sample_1 = model_1(image).sample
sample_2 = model_2(image).sample
assert sample_1.shape == sample_2.shape
output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()
assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
...@@ -752,6 +752,42 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
)
pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
images.append(output.images[0])
del pipe
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3
@slow
@require_torch_gpu
...
...@@ -401,3 +401,49 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
)
assert np.abs(expected_image - image).max() < 9e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
)
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
output = pipe(
prompt,
image=image,
control_image=control_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
)
images.append(output.images[0])
del pipe
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3
...@@ -543,3 +543,54 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
)
assert np.abs(expected_image - image).max() < 9e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetInpaintPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
)
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
).resize((512, 512))
pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
output = pipe(
prompt,
image=image,
control_image=control_image,
mask_image=mask_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
)
images.append(output.images[0])
del pipe
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3