Commit e4ffadc4 authored by Patrick von Platen

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents ec7c8d32 c9b34637
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.

# Stable Diffusion text-to-image fine-tuning

The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) script shows how to fine-tune the stable diffusion model on your own dataset.

<Tip warning={true}>
...
@@ -557,6 +557,9 @@ class CrossAttention(nn.Module):
        return hidden_states

    def _memory_efficient_attention_xformers(self, query, key, value):
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states
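The three `.contiguous()` calls are the substance of this hunk: `xformers.ops.memory_efficient_attention` expects contiguous tensors in the `(batch * heads, seq_len, head_dim)` layout, and the reshaped views that reach this point may not be contiguous. A minimal standalone check of the equivalence this method relies on (assuming `xformers` is installed and a CUDA device is available, both assumptions rather than part of this diff):

import torch
import xformers.ops

# random query/key/value in the (batch * heads, seq_len, head_dim) layout
q = torch.randn(2, 64, 40, device="cuda", dtype=torch.float16)
k = torch.randn(2, 64, 40, device="cuda", dtype=torch.float16)
v = torch.randn(2, 64, 40, device="cuda", dtype=torch.float16)

# reference: plain scaled-dot-product attention
attn = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1) @ v

# memory-efficient path; inputs must be contiguous, hence the .contiguous() calls above
out = xformers.ops.memory_efficient_attention(q.contiguous(), k.contiguous(), v.contiguous())

print(torch.allclose(attn, out, atol=1e-2))  # True up to fp16 tolerance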
...
@@ -19,6 +19,7 @@ import numpy as np
import torch
import PIL

from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
@@ -178,6 +179,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
            feature_extractor=feature_extractor,
        )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.
@@ -197,14 +199,33 @@ class CycleDiffusionPipeline(DiffusionPipeline):
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device("cuda")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)
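As a usage sketch (the checkpoint id and prompt are illustrative, not part of this diff), the offload hooks are installed once and are transparent at call time:

import torch
from diffusers import StableDiffusionPipeline

# assumption: any Stable Diffusion checkpoint works; this id is just an example
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# weights stay on CPU / the meta device until a submodule's forward() runs
pipe.enable_sequential_cpu_offload()

# no .to("cuda") needed; accelerate's hooks move each submodule on demand
image = pipe("a photo of an astronaut riding a horse").images[0]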
    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
@@ -224,6 +245,26 @@ class CycleDiffusionPipeline(DiffusionPipeline):
                return torch.device(module._hf_hook.execution_device)
        return self.device
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.unet.set_use_memory_efficient_attention_xformers(False)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
@@ -310,6 +351,106 @@ class CycleDiffusionPipeline(DiffusionPipeline):
        return text_embeddings
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
    def check_inputs(self, prompt, strength, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
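The pattern here is plain signature introspection, which is why it works for any scheduler without type checks. A self-contained sketch of the same idea (the toy scheduler classes are illustrative, not diffusers code):

import inspect

class DDIMLikeScheduler:
    def step(self, model_output, t, sample, eta=0.0, generator=None):
        ...

class PNDMLikeScheduler:
    def step(self, model_output, t, sample):
        ...

def extra_kwargs(scheduler, eta, generator):
    # only forward kwargs that this scheduler's step() actually declares
    params = set(inspect.signature(scheduler.step).parameters)
    kwargs = {}
    if "eta" in params:
        kwargs["eta"] = eta
    if "generator" in params:
        kwargs["generator"] = generator
    return kwargs

print(extra_kwargs(DDIMLikeScheduler(), 0.5, None))  # {'eta': 0.5, 'generator': None}
print(extra_kwargs(PNDMLikeScheduler(), 0.5, None))  # {}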
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
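The constant 0.18215 is the Stable Diffusion VAE scaling factor: latents are multiplied by it at encode time (see `prepare_latents` below) and divided by it here, and the decoder output in [-1, 1] is mapped to [0, 1]. A minimal numeric check of just the scaling and range arithmetic (no VAE involved):

import torch

scale = 0.18215
latents = torch.randn(1, 4, 64, 64)

# encode-side scaling followed by decode-side unscaling is the identity
assert torch.allclose(1 / scale * (scale * latents), latents)

# the decoder output lives in [-1, 1]; this maps it to [0, 1] for image conversion
decoded = torch.tensor([-1.0, 0.0, 1.0])
print((decoded / 2 + 0.5).clamp(0, 1))  # tensor([0.0000, 0.5000, 1.0000])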
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps
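Concretely, `strength` controls how much of the schedule is actually run: with `num_inference_steps=50`, `strength=0.3` and `offset=1`, `init_timestep = 15 + 1 = 16` and `t_start = 50 - 16 + 1 = 35`, so only the last 15 of the 50 scheduler timesteps are used. The arithmetic in isolation:

def img2img_window(num_inference_steps, strength, offset=0):
    # mirrors get_timesteps above, on plain ints instead of a scheduler
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    return t_start, num_inference_steps - t_start  # first index used, steps actually run

print(img2img_window(50, 0.3, offset=1))  # (35, 15)
print(img2img_window(50, 1.0))            # (0, 50) -- full schedule, image fully re-noised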
    def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        init_image = init_image.to(device=device, dtype=dtype)
        init_latent_dist = self.vae.encode(init_image).latent_dist
        init_latents = init_latent_dist.sample(generator=generator)
        init_latents = 0.18215 * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`init_image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many init images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)

        # add noise to latents using the timestep
        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)

        # get latents
        clean_latents = init_latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents, clean_latents
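`scheduler.add_noise` is what makes `strength` work: it jumps the clean latents straight to the chosen point of the forward diffusion, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A hedged sketch using diffusers' DDIMScheduler (the scheduler choice and numbers are illustrative):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

x0 = torch.randn(1, 4, 64, 64)  # stand-in for "clean" latents from the VAE
noise = torch.randn_like(x0)
t = scheduler.timesteps[35:36]  # deep into the schedule -> mostly noise

xt = scheduler.add_noise(x0, noise, t)
# the later the timestep, the weaker the remaining correlation with x0
print(torch.corrcoef(torch.stack([x0.flatten(), xt.flatten()]))[0, 1])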
    @torch.no_grad()
    def __call__(
        self,
@@ -384,112 +525,43 @@ class CycleDiffusionPipeline(DiffusionPipeline):
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 1. Check inputs
        self.check_inputs(prompt, strength, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
        source_text_embeddings = self._encode_prompt(
            source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
        )

        # 4. Preprocess image
        if isinstance(init_image, PIL.Image.Image):
            init_image = preprocess(init_image)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents, clean_latents = self.prepare_latents(
            init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
        )
        source_latents = latents

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        generator = extra_step_kwargs.pop("generator", None)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2)
@@ -551,22 +623,13 @@ class CycleDiffusionPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 9. Post-processing
        image = self.decode_latents(latents)

        # 10. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 11. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)
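Read end to end, the numbered steps above are the whole public contract of the pipeline. A hedged usage sketch (the checkpoint id, image path, prompts and parameter values are all illustrative; CycleDiffusion is built around the DDIM scheduler with a non-zero `eta`):

import torch
from PIL import Image
from diffusers import CycleDiffusionPipeline, DDIMScheduler

# assumption: any SD 1.x checkpoint; this id is just an example
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("horse.png").convert("RGB").resize((512, 512))  # illustrative path

out = pipe(
    prompt="an astronaut riding an elephant",
    source_prompt="an astronaut riding a horse",
    init_image=init_image,
    strength=0.8,     # how much of the schedule to traverse
    eta=0.1,          # DDIM stochasticity; CycleDiffusion needs eta > 0
    guidance_scale=2.0,
).images[0]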
...
@@ -298,6 +298,73 @@ class StableDiffusionPipeline(DiffusionPipeline):
        return text_embeddings
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // 8, width // 8)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
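The `mps` branch exists because seeded `torch.randn` on the `mps` device is not reproducible across runs; sampling on CPU and moving the tensor over gives deterministic latents on any backend. The trick in isolation:

import torch

def reproducible_randn(shape, seed, device):
    # always draw on CPU so the same seed gives the same noise on cpu/cuda/mps
    generator = torch.Generator("cpu").manual_seed(seed)
    return torch.randn(shape, generator=generator, device="cpu").to(device)

a = reproducible_randn((1, 4, 64, 64), seed=0, device="cpu")
b = reproducible_randn((1, 4, 64, 64), seed=0, device="cpu")
print(torch.equal(a, b))  # True -- identical noise for identical seeds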
    @torch.no_grad()
    def __call__(
        self,
@@ -371,75 +438,45 @@ class StableDiffusionPipeline(DiffusionPipeline):
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -459,22 +496,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)
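The hunk collapsed between `scale_model_input` and the callback hides the classifier-free-guidance step itself. For orientation, the standard combination step in Stable Diffusion pipelines looks like this (a sketch of the well-known CFG formula, not necessarily the exact lines elided here):

# inside the denoising loop, after scaling the model input
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

if do_classifier_free_guidance:
    # first half of the batch was conditioned on the empty prompt, second half on the prompt
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# step the scheduler to the previous (less noisy) latents
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample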
...
@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
@@ -78,6 +79,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
    def __init__(
        self,
        vae: AutoencoderKL,
@@ -85,7 +87,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
@@ -139,6 +146,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            feature_extractor=feature_extractor,
        )
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.
@@ -158,14 +166,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -202,6 +212,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                return torch.device(module._hf_hook.execution_device)
        return self.device
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.
@@ -214,6 +225,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
@@ -306,6 +318,103 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        return text_embeddings
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
    def check_inputs(self, prompt, strength, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps
    def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        init_image = init_image.to(device=device, dtype=dtype)
        init_latent_dist = self.vae.encode(init_image).latent_dist
        init_latents = init_latent_dist.sample(generator=generator)
        init_latents = 0.18215 * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`init_image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many init images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents
    @torch.no_grad()
    def __call__(
        self,
@@ -379,102 +488,40 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 1. Check inputs
        self.check_inputs(prompt, strength, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Preprocess image
        if isinstance(init_image, PIL.Image.Image):
            init_image = preprocess(init_image)

        # 5. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -495,20 +542,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 9. Post-processing
        image = self.decode_latents(latents)

        # 10. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 11. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)
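A hedged usage sketch tying `strength` back to the timestep arithmetic shown earlier (the checkpoint id, image path and parameter values are illustrative):

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))  # illustrative path

# strength=0.75 with 50 steps runs roughly the last 37 denoising steps on a
# partially re-noised image: low strength stays close to the input,
# high strength trusts the prompt more
image = pipe(
    prompt="a fantasy landscape, trending on artstation",
    init_image=init_image,
    strength=0.75,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]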
...
@@ -139,6 +139,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            feature_extractor=feature_extractor,
        )
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.
@@ -158,6 +159,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
@@ -166,6 +168,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -183,6 +186,26 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.unet.set_use_memory_efficient_attention_xformers(False)
    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
@@ -202,24 +225,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                return torch.device(module._hf_hook.execution_device)
        return self.device
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
@@ -306,6 +311,106 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        return text_embeddings
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // 8, width // 8)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
        mask = mask.to(device=device, dtype=dtype)

        masked_image = masked_image.to(device=device, dtype=dtype)

        # encode the mask image into latents space so we can concatenate it to the latents
        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
        masked_image_latents = 0.18215 * masked_image_latents

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        mask = mask.repeat(batch_size, 1, 1, 1)
        masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)

        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        # aligning device to prevent device errors when concatenating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        return mask, masked_image_latents
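The shape bookkeeping in this helper is easy to get wrong, so here is the same flow on dummy tensors (sizes are illustrative, and the VAE encode is replaced by a random stand-in since only shapes are being checked):

import torch
import torch.nn.functional as F

batch_size, height, width = 2, 512, 512
do_classifier_free_guidance = True

mask = torch.rand(1, 1, height, width)            # 1 = inpaint, 0 = keep
masked_image_latents = torch.randn(1, 4, 64, 64)  # stand-in for vae.encode(...).sample() * 0.18215

# downsample the pixel-space mask to the 8x-smaller latent grid
mask = F.interpolate(mask, size=(height // 8, width // 8))

# duplicate per prompt, then double everything for classifier-free guidance
mask = mask.repeat(batch_size, 1, 1, 1)
masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
if do_classifier_free_guidance:
    mask = torch.cat([mask] * 2)
    masked_image_latents = torch.cat([masked_image_latents] * 2)

print(mask.shape, masked_image_latents.shape)
# torch.Size([4, 1, 64, 64]) torch.Size([4, 4, 64, 64])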
    @torch.no_grad()
    def __call__(
        self,
@@ -390,83 +495,59 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            (nsfw) content, according to the `safety_checker`.
        """
        # 1. Check inputs
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Preprocess mask and image
        if isinstance(image, PIL.Image.Image) and isinstance(mask_image, PIL.Image.Image):
            mask, masked_image = prepare_mask_and_masked_image(image, mask_image)

        # 5. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
) )
# aligning device to prevent device errors when concating it with the latent model input # 7. Prepare mask latent variables
masked_image_latents = masked_image_latents.to(device=device, dtype=text_embeddings.dtype) mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size,
height,
width,
text_embeddings.dtype,
device,
generator,
do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
num_channels_mask = mask.shape[1] num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1] num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
raise ValueError( raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
...@@ -476,27 +557,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -476,27 +557,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
# set timesteps and move to the correct device # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
self.scheduler.set_timesteps(num_inference_steps, device=device) extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
timesteps_tensor = self.scheduler.timesteps
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
# 10. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps_tensor)): for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
...@@ -521,22 +585,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -521,22 +585,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) callback(i, t, latents)
latents = 1 / 0.18215 * latents # 11. Post-processing
image = self.vae.decode(latents).sample image = self.decode_latents(latents)
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if self.safety_checker is not None: # 12. Run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
# 13. Convert to PIL
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
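The `guidance_scale` comment above refers to the standard classifier-free guidance blend: the UNet is run on a stacked unconditional/text-conditional batch and the two noise predictions are combined. A toy, self-contained version with random tensors standing in for real predictions:

    import torch

    guidance_scale = 7.5
    noise_pred = torch.randn(2, 4, 64, 64)  # stacked [unconditional, text-conditional] batch
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)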
@@ -19,13 +19,20 @@ import numpy as np
import torch
import PIL
- from tqdm.auto import tqdm
+ from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
- from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+ from ...schedulers import (
+     DDIMScheduler,
+     DPMSolverMultistepScheduler,
+     EulerAncestralDiscreteScheduler,
+     EulerDiscreteScheduler,
+     LMSDiscreteScheduler,
+     PNDMScheduler,
+ )
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

@@ -85,17 +92,26 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ scheduler: Union[
+     DDIMScheduler,
+     PNDMScheduler,
+     LMSDiscreteScheduler,
+     EulerDiscreteScheduler,
+     EulerAncestralDiscreteScheduler,
+     DPMSolverMultistepScheduler,
+ ],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"

@@ -143,6 +159,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor=feature_extractor,
)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.

@@ -162,14 +179,53 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
- # set slice_size = `None` to disable `set_attention_slice`
+ # set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
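A hypothetical usage sketch; the model id is only an example, and this path needs `accelerate` installed and a CUDA device available:

    from diffusers import StableDiffusionInpaintPipelineLegacy

    pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe.enable_sequential_cpu_offload()  # submodules are paged onto the GPU per forward call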
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
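Likewise, a sketch of toggling the xformers attention path, assuming `xformers` is installed and the same example checkpoint as above:

    from diffusers import StableDiffusionInpaintPipelineLegacy

    pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
    pipe.enable_xformers_memory_efficient_attention()
    # ... run inference here ...
    pipe.disable_xformers_memory_efficient_attention()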
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):

@@ -275,6 +331,88 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return text_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
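A toy walkthrough of the arithmetic in `decode_latents`: the VAE decodes to roughly [-1, 1], which is mapped to [0, 1] and laid out as NHWC float32 for NumPy/PIL. The random tensor below stands in for `self.vae.decode(latents).sample`:

    import torch

    decoded = torch.randn(1, 3, 512, 512).clamp(-1, 1)        # stand-in for vae.decode(...).sample
    image = (decoded / 2 + 0.5).clamp(0, 1)                   # [-1, 1] -> [0, 1]
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()   # NCHW -> NHWC float32
    print(image.shape)  # (1, 512, 512, 3)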
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
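This introspection is what lets one `__call__` drive many schedulers. For instance, DDIM's `step` accepts `eta` while PNDM's does not, so the kwarg is filtered out rather than crashing (a quick check, indicative for the schedulers shipped in this release):

    import inspect
    from diffusers import DDIMScheduler, PNDMScheduler

    print("eta" in inspect.signature(DDIMScheduler.step).parameters)  # True
    print("eta" in inspect.signature(PNDMScheduler.step).parameters)  # False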
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(self, prompt, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps
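A worked example of the slicing in `get_timesteps`: with `num_inference_steps=50`, `strength=0.75`, and `steps_offset=1`, denoising starts 13 entries into the schedule and runs the remaining 37 timesteps:

    num_inference_steps, strength, offset = 50, 0.75, 1
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    print(init_timestep, t_start, num_inference_steps - t_start)  # 38 13 37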
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
init_image = init_image.to(device=self.device, dtype=dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents, init_latents_orig, noise
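A minimal, standalone sketch of the noising step `prepare_latents` performs: the scaled VAE encoding is pushed to the start timestep with `scheduler.add_noise`. The scheduler settings below are illustrative defaults, not read from a checkpoint:

    import torch
    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    scheduler.set_timesteps(50)

    init_latents = 0.18215 * torch.randn(1, 4, 64, 64)  # stand-in for the encoded init image
    noise = torch.randn_like(init_latents)
    noisy_latents = scheduler.add_noise(init_latents, noise, scheduler.timesteps[:1])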
@torch.no_grad()
def __call__(
self,

@@ -353,98 +491,49 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
- if isinstance(prompt, str):
-     batch_size = 1
- elif isinstance(prompt, list):
-     batch_size = len(prompt)
- else:
-     raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
- if strength < 0 or strength > 1:
-     raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
- if (callback_steps is None) or (
-     callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
-     raise ValueError(
-         f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-         f" {type(callback_steps)}."
-     )
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
- # set timesteps
- self.scheduler.set_timesteps(num_inference_steps)
- # preprocess image
- if not isinstance(init_image, torch.FloatTensor):
-     init_image = preprocess_image(init_image)
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
+ # 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
- # encode the init image into latents and scale the latents
- latents_dtype = text_embeddings.dtype
- init_image = init_image.to(device=self.device, dtype=latents_dtype)
- init_latent_dist = self.vae.encode(init_image).latent_dist
- init_latents = init_latent_dist.sample(generator=generator)
- init_latents = 0.18215 * init_latents
- # Expand init_latents for batch_size and num_images_per_prompt
- init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
- init_latents_orig = init_latents
+ # 4. Preprocess image and mask
+ if not isinstance(init_image, torch.FloatTensor):
+     init_image = preprocess_image(init_image)
- # preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
- mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
- mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
- # check sizes
- if not mask.shape == init_latents.shape:
-     raise ValueError("The mask and init_image should be the same size!")
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- # get the original timestep using init_timestep
- offset = self.scheduler.config.get("steps_offset", 0)
- init_timestep = int(num_inference_steps * strength) + offset
- init_timestep = min(init_timestep, num_inference_steps)
- timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+ # 6. Prepare latent variables
+ # encode the init image into latents and scale the latents
+ latents, init_latents_orig, noise = self.prepare_latents(
+     init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
- # add noise to latents using the timesteps
- noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
- init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+ # 7. Prepare mask latent
+ mask = mask_image.to(device=self.device, dtype=latents.dtype)
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
-     extra_step_kwargs["eta"] = eta
- # check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
- if accepts_generator:
-     extra_step_kwargs["generator"] = generator
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- latents = init_latents
- t_start = max(num_inference_steps - init_timestep + offset, 0)
- # Some schedulers like PNDM have timesteps as arrays
- # It's more optimized to move all timesteps to correct device beforehand
- timesteps = self.scheduler.timesteps[t_start:].to(self.device)
- for i, t in tqdm(enumerate(timesteps)):
+ # 9. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -468,22 +557,13 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
- latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
- image = (image / 2 + 0.5).clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).numpy()
+ # 10. Post-processing
+ image = self.decode_latents(latents)
- if self.safety_checker is not None:
-     safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-         self.device
-     )
-     image, has_nsfw_concept = self.safety_checker(
-         images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
-     )
- else:
-     has_nsfw_concept = None
+ # 11. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+ # 12. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
@@ -22,6 +22,7 @@ import torch
from diffusers import (
AutoencoderKL,
+ DDIMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,

@@ -479,7 +480,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
init_image = init_image.resize((768, 512))
expected_image = load_numpy(
- "https://huggingface.co/datasets/lewington/expected-images/resolve/main/fantasy_landscape.npy"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.npy"
)
model_id = "CompVis/stable-diffusion-v1-4"

@@ -506,7 +507,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
assert image.shape == (512, 768, 3)
# img2img is flaky across GPUs even in fp32, so using MAE here
- assert np.abs(expected_image - image).mean() < 1e-3
+ assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_img2img_pipeline_k_lms(self):
init_image = load_image(

@@ -515,7 +516,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
init_image = init_image.resize((768, 512))
expected_image = load_numpy(
- "https://huggingface.co/datasets/lewington/expected-images/resolve/main/fantasy_landscape_k_lms.npy"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_k_lms.npy"
)
model_id = "CompVis/stable-diffusion-v1-4"

@@ -543,8 +544,44 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
image = output.images[0]
assert image.shape == (512, 768, 3)
- # img2img is flaky across GPUs even in fp32, so using MAE here
- assert np.abs(expected_image - image).mean() < 1e-3
+ assert np.abs(expected_image - image).max() < 1e-3
+ def test_stable_diffusion_img2img_pipeline_ddim(self):
+     init_image = load_image(
+         "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+         "/img2img/sketch-mountains-input.jpg"
+     )
+     init_image = init_image.resize((768, 512))
+     expected_image = load_numpy(
+         "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_ddim.npy"
+     )
+     model_id = "CompVis/stable-diffusion-v1-4"
+     ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler")
+     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+         model_id,
+         scheduler=ddim,
+         safety_checker=None,
+     )
+     pipe.to(torch_device)
+     pipe.set_progress_bar_config(disable=None)
+     pipe.enable_attention_slicing()
+     prompt = "A fantasy landscape, trending on artstation"
+     generator = torch.Generator(device=torch_device).manual_seed(0)
+     output = pipe(
+         prompt=prompt,
+         init_image=init_image,
+         strength=0.75,
+         guidance_scale=7.5,
+         generator=generator,
+         output_type="np",
+     )
+     image = output.images[0]
+     assert image.shape == (512, 768, 3)
+     assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_img2img_intermediate_state(self):
number_of_steps = 0
@@ -387,7 +387,6 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_inpaint_legacy_pipeline_k_lms(self):
- # TODO(Anton, Patrick) - I think we can remove this test soon
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
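These integration tests compare against golden images with a max absolute-difference tolerance rather than exact equality; a toy equivalent of the assertion pattern:

    import numpy as np

    expected_image = np.zeros((512, 768, 3), dtype=np.float32)
    image = expected_image + 5e-4  # pretend pipeline output, within tolerance
    assert image.shape == (512, 768, 3)
    assert np.abs(expected_image - image).max() < 1e-3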