Unverified Commit 5c7a35a2 authored by Patrick von Platen, committed by GitHub

[Torch 2.0 compile] Fix more torch compile breaks (#3313)



* Fix more torch compile breaks

* add tests

* Fix all

* fix controlnet

* fix more

* Add Horace He as co-author.

---------

Co-authored-by: Horace He <horacehe2007@yahoo.com>
parent a7f25b4a
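Note: the sketch below is not part of the diff; it only illustrates, under the assumption of a standard diffusers denoising loop, the call-site pattern this commit moves to. Requesting tuple outputs with `return_dict=False` and indexing them replaces attribute access on the output dataclasses (`.sample`, `.prev_sample`), which helps `torch.compile(..., fullgraph=True)` trace the step without graph breaks.

import torch
from diffusers import DDIMScheduler, UNet2DConditionModel

# Loading real SD 1.4 weights here is only to keep the sketch self-contained.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
unet = torch.compile(unet, mode="reduce-overhead", fullgraph=True)

scheduler.set_timesteps(2)
latents = torch.randn(1, 4, 64, 64)
prompt_embeds = torch.randn(1, 77, 768)
for t in scheduler.timesteps:
    # before this PR: unet(...).sample and scheduler.step(...).prev_sample
    noise_pred = unet(latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]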
@@ -498,7 +498,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
@@ -517,7 +517,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
-        sample += controlnet_cond
+        sample = sample + controlnet_cond
         # 3. down
         down_block_res_samples = (sample,)
@@ -551,7 +551,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
             down_block_res_sample = controlnet_block(down_block_res_sample)
-            controlnet_down_block_res_samples += (down_block_res_sample,)
+            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
         down_block_res_samples = controlnet_down_block_res_samples
@@ -559,13 +559,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         # 6. scaling
         if guess_mode:
-            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1)  # 0.1 to 1.0
-            scales *= conditioning_scale
+            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
+            scales = scales * conditioning_scale
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
-            mid_block_res_sample *= scales[-1]  # last one
+            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
         else:
             down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-            mid_block_res_sample *= conditioning_scale
+            mid_block_res_sample = mid_block_res_sample * conditioning_scale
         if self.config.global_pool_conditions:
             down_block_res_samples = [
......
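Note (not part of the diff): the ControlNetModel edits above are pure out-of-place rewrites; numerically they match the in-place originals, but they avoid augmented assignment on tensors and tuples and build `scales` on the sample's device, both of which are friendlier to TorchDynamo tracing under fullgraph=True. A minimal sketch with plain torch tensors standing in for the real activations:

import torch

sample = torch.randn(2, 320, 64, 64)
controlnet_cond = torch.randn_like(sample)

in_place = sample.clone()
in_place += controlnet_cond              # original style
out_of_place = sample + controlnet_cond  # compile-friendly style
assert torch.allclose(in_place, out_of_place)

# guess-mode scaling: create the scales on the same device as the activations
# instead of letting torch.logspace default to CPU
num_res = 13
scales = torch.logspace(-1, 0, num_res, device=sample.device)  # 0.1 ... 1.0
scales = scales * 0.5  # conditioning_scale, applied out-of-place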
@@ -740,7 +740,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 down_block_res_samples, down_block_additional_residuals
             ):
                 down_block_res_sample = down_block_res_sample + down_block_additional_residual
-                new_down_block_res_samples += (down_block_res_sample,)
+                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
             down_block_res_samples = new_down_block_res_samples
......
@@ -457,7 +457,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             FutureWarning,
         )
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -728,7 +728,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -736,7 +737,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -745,7 +746,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                         callback(i, t, latents)
         if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
         else:
             image = latents
......
@@ -918,7 +918,8 @@ class IFImg2ImgPipeline(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -930,8 +931,8 @@ class IFImg2ImgPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
@@ -1036,7 +1036,8 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
                     encoder_hidden_states=prompt_embeds,
                     class_labels=noise_level,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -1048,8 +1049,8 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
@@ -1033,7 +1033,8 @@ class IFInpaintingPipeline(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -1047,8 +1048,8 @@ class IFInpaintingPipeline(DiffusionPipeline):
                 prev_intermediate_images = intermediate_images
                 intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]
                 intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images
......
@@ -1143,7 +1143,8 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
                     encoder_hidden_states=prompt_embeds,
                     class_labels=noise_level,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -1157,8 +1158,8 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
                 prev_intermediate_images = intermediate_images
                 intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]
                 intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images
......
@@ -886,7 +886,8 @@ class IFSuperResolutionPipeline(DiffusionPipeline):
                     encoder_hidden_states=prompt_embeds,
                     class_labels=noise_level,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -898,8 +899,8 @@ class IFSuperResolutionPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
-                    noise_pred, t, intermediate_images, **extra_step_kwargs
-                ).prev_sample
+                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
+                )[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import PIL.Image
 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
@@ -579,9 +580,20 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             )
         # Check `image`
-        if isinstance(self.controlnet, ControlNetModel):
+        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+        )
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
             self.check_image(image, prompt, prompt_embeds)
-        elif isinstance(self.controlnet, MultiControlNetModel):
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
             if not isinstance(image, list):
                 raise TypeError("For multiple controlnets: `image` must be type `list`")
@@ -600,10 +612,18 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             assert False
         # Check `controlnet_conditioning_scale`
-        if isinstance(self.controlnet, ControlNetModel):
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
             if not isinstance(controlnet_conditioning_scale, float):
                 raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
-        elif isinstance(self.controlnet, MultiControlNetModel):
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
             if isinstance(controlnet_conditioning_scale, list):
                 if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                     raise ValueError("A single batch of multiple conditionings are supported at the moment.")
@@ -910,7 +930,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         )
         # 4. Prepare image
-        if isinstance(self.controlnet, ControlNetModel):
+        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+        )
+        if (
+            isinstance(self.controlnet, ControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, ControlNetModel)
+        ):
             image = self.prepare_image(
                 image=image,
                 width=width,
@@ -922,7 +949,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
-        elif isinstance(self.controlnet, MultiControlNetModel):
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
             images = []
             for image_ in image:
@@ -1006,7 +1037,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -1014,7 +1046,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
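Note (not part of the diff): the `is_compiled` / `_orig_mod` checks added above exist because `torch.compile` returns a wrapper module, so a plain `isinstance(self.controlnet, ControlNetModel)` no longer holds once the controlnet has been compiled. A small sketch with a throwaway `nn.Linear` standing in for the controlnet:

import torch
from torch import nn

net = nn.Linear(4, 4)
compiled = torch.compile(net)

print(isinstance(compiled, nn.Linear))                                 # False: it is a wrapper
print(isinstance(compiled, torch._dynamo.eval_frame.OptimizedModule))  # True
print(isinstance(compiled._orig_mod, nn.Linear))                       # True: original module is kept here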
@@ -677,7 +677,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
                 latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+                    0
+                ]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -685,7 +687,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
@@ -462,7 +462,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
             FutureWarning,
         )
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -734,7 +734,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -742,7 +743,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -751,7 +752,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
                         callback(i, t, latents)
         if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
         else:
             image = latents
......
@@ -878,7 +878,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+                    0
+                ]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -886,7 +888,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
@@ -690,7 +690,9 @@ class StableDiffusionInpaintPipelineLegacy(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
+                    0
+                ]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -698,7 +700,7 @@ class StableDiffusionInpaintPipelineLegacy(
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 # masking
                 if add_predicted_noise:
                     init_latents_proper = self.scheduler.add_noise(
......
@@ -346,7 +346,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
                 scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
                 # predict the noise residual
-                noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+                noise_pred = self.unet(
+                    scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False
+                )[0]
                 # Hack:
                 # For karras style schedulers the model does classifer free guidance using the
@@ -376,7 +378,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
                     noise_pred = (noise_pred - latents) / (-sigma)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
@@ -678,8 +678,12 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
                 # predict the noise residual
                 noise_pred = self.unet(
-                    latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level
-                ).sample
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    class_labels=noise_level,
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -687,7 +691,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
@@ -830,7 +830,8 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     timestep=t,
                     sample=prior_latents,
                     **prior_extra_step_kwargs,
-                ).prev_sample
+                    return_dict=False,
+                )[0]
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, prior_latents)
@@ -903,7 +904,8 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     encoder_hidden_states=prompt_embeds,
                     class_labels=image_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -911,7 +913,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
......
@@ -799,7 +799,8 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                     encoder_hidden_states=prompt_embeds,
                     class_labels=image_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -807,7 +808,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
......
@@ -843,7 +843,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 down_block_res_samples, down_block_additional_residuals
             ):
                 down_block_res_sample = down_block_res_sample + down_block_additional_residual
-                new_down_block_res_samples += (down_block_res_sample,)
+                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
             down_block_res_samples = new_down_block_res_samples
......
@@ -866,6 +866,28 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
         max_diff = np.abs(expected_image - image).max()
         assert max_diff < 5e-2
+    def test_stable_diffusion_compile(self):
+        if version.parse(torch.__version__) < version.parse("2.0"):
+            print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
+            return
+
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
+        sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+
+        sd_pipe.unet.to(memory_format=torch.channels_last)
+        sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
+        assert np.abs(image_slice - expected_slice).max() < 5e-3
 @slow
 @require_torch_gpu
@@ -922,28 +944,6 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
         assert np.max(np.abs(image - image_ckpt)) < 1e-4
-    def test_stable_diffusion_compile(self):
-        if version.parse(torch.__version__) >= version.parse("2.0"):
-            print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
-            return
-
-        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
-        sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
-        sd_pipe = sd_pipe.to(torch_device)
-
-        sd_pipe.unet.to(memory_format=torch.channels_last)
-        sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
-
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_inputs(torch_device)
-        image = sd_pipe(**inputs).images
-        image_slice = image[0, -3:, -3:, -1].flatten()
-
-        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
-        assert np.abs(image_slice - expected_slice).max() < 1e-4
 @nightly
 @require_torch_gpu
......
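Note (not part of the diff): the compile test gates on the installed torch version with an early return. An equivalent pattern, sketched here with hypothetical names, is a `unittest.skipIf` decorator so runs on torch < 2.0 are reported as skips rather than silent passes:

import unittest

import torch
from packaging import version

requires_torch_2 = unittest.skipIf(
    version.parse(torch.__version__) < version.parse("2.0"), "torch.compile needs torch>=2.0"
)


@requires_torch_2
class CompileSmokeTest(unittest.TestCase):
    def test_compile_runs(self):
        # hypothetical toy module; the real tests compile the pipeline's UNet instead
        net = torch.compile(torch.nn.Linear(4, 4))
        out = net(torch.randn(2, 4))
        self.assertEqual(tuple(out.shape), (2, 4))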
@@ -19,6 +19,7 @@ import unittest
 import numpy as np
 import torch
+from packaging import version
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 from diffusers import (
@@ -585,6 +586,42 @@ class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase):
         expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+    def test_stable_diffusion_compile(self):
+        if version.parse(torch.__version__) < version.parse("2.0"):
+            print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
+            return
+
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+        )
+        pipe.to("cuda")
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.unet.to(memory_format=torch.channels_last)
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+        pipe.controlnet.to(memory_format=torch.channels_last)
+        pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "bird"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        output = pipe(prompt, image, generator=generator, output_type="np")
+        image = output.images[0]
+
+        assert image.shape == (768, 512, 3)
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
+        )
+
+        assert np.abs(expected_image - image).max() < 1e-1
 @slow
 @require_torch_gpu
......