Unverified commit ee71d9d0 authored by clarencechen, committed by GitHub

Add support for different model prediction types in DDIMInverseScheduler (#2619)



* Add support for different model prediction types in DDIMInverseScheduler
Resolve alpha_prod_t_prev index issue for final step of inversion

* Fix old bug introduced when prediction type is "sample"

* Add support for sample clipping for numerical stability and deprecate old kwarg

* Detach sample, alphas, betas

Derive predicted noise from model output before dist. regularization

Style cleanup

* Log loss for debugging

* Revert "Log loss for debugging"

This reverts commit 76ea9c856f99f4c8eca45a0b1801593bb982584b.

* Add comments

* Add inversion equivalence test

* Add expected data for Pix2PixZero pipeline tests with SD 2

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

* Remove cruft and add more explanatory comments

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 268ebcb0
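Taken together, the changes in this commit let the Pix2Pix-Zero pipeline invert real images with v-prediction checkpoints such as Stable Diffusion 2.1. The snippet below is a minimal usage sketch assembled from the updated docstring example and the new slow tests; the local image path `cat.png` and the single-prompt lists are illustrative placeholders, not part of the commit.

```python
import torch
from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionPix2PixZeroPipeline
from diffusers.utils import load_image

# Stable Diffusion 2.1 is a v-prediction model; DDIMInverseScheduler now honors prediction_type
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

caption = "a photography of a cat with flowers"
raw_image = load_image("cat.png").resize((512, 512))  # placeholder input image

# invert the real image into latents that regenerate it under DDIM sampling
generator = torch.manual_seed(0)
inv_latents = pipe.invert(caption, image=raw_image, generator=generator)[0]

# edit the inverted latents from the source concept towards the target concept
source_embeds = pipe.get_embeds(["a cat sitting on the street"])
target_embeds = pipe.get_embeds(["a dog sitting on the street"])
image = pipe(
    caption,
    source_embeds=source_embeds,
    target_embeds=target_embeds,
    num_inference_steps=50,
    # per the updated docstring, keep this at 0.01 or less for v-prediction models
    # to avoid input latent gradient explosion
    cross_attention_guidance_amount=0.01,
    generator=generator,
    latents=inv_latents,
    negative_prompt=caption,
).images[0]
```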
@@ -153,6 +153,8 @@ EXAMPLE_INVERT_DOC_STRING = """
        >>> source_embeds = pipeline.get_embeds(source_prompts)
        >>> target_embeds = pipeline.get_embeds(target_prompts)
        >>> # the latents can then be used to edit a real image
+        >>> # when using Stable Diffusion 2 or other models that use v-prediction
+        >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion
        >>> image = pipeline(
        ...     caption,
@@ -730,6 +732,23 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
        return latents

+    def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
+        pred_type = self.inverse_scheduler.config.prediction_type
+        alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
+        beta_prod_t = 1 - alpha_prod_t
+        if pred_type == "epsilon":
+            return model_output
+        elif pred_type == "sample":
+            return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
+        elif pred_type == "v_prediction":
+            return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
+            )
+
    def auto_corr_loss(self, hidden_states, generator=None):
        batch_size, channel, height, width = hidden_states.shape
        if batch_size > 1:
@@ -1156,8 +1175,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
        # 7. Denoising loop where we obtain the cross-attention maps.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
-        with self.progress_bar(total=num_inference_steps - 2) as progress_bar:
-            for i, t in enumerate(timesteps[1:-1]):
+        with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
+            for i, t in enumerate(timesteps[:-1]):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
@@ -1181,7 +1200,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
                if lambda_auto_corr > 0:
                    for _ in range(num_auto_corr_rolls):
                        var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
-                        l_ac = self.auto_corr_loss(var, generator=generator)
+
+                        # Derive epsilon from model output before regularizing to IID standard normal
+                        var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
+
+                        l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
                        l_ac.backward()

                        grad = var.grad.detach() / num_auto_corr_rolls
@@ -1190,7 +1213,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
                if lambda_kl > 0:
                    var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
-                    l_kld = self.kl_divergence(var)
+                    # Derive epsilon from model output before regularizing to IID standard normal
+                    var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
+
+                    l_kld = self.kl_divergence(var_epsilon)
                    l_kld.backward()

                    grad = var.grad.detach()
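The `get_epsilon` helper added above exists because the auto-correlation and KL regularizers in `invert()` are meant to push the predicted noise towards an IID standard normal, but only `epsilon`-type models emit that noise directly; for `sample` and `v_prediction` models the UNet output has to be converted first. A standalone restatement of those conversions (a sketch mirroring the method above; the function name and signature are illustrative):

```python
import torch


def epsilon_from_model_output(
    model_output: torch.Tensor, sample: torch.Tensor, alpha_prod_t: torch.Tensor, prediction_type: str
) -> torch.Tensor:
    """Recover the implied noise prediction from whatever quantity the model was trained to predict."""
    beta_prod_t = 1 - alpha_prod_t
    if prediction_type == "epsilon":
        # the model already predicts the noise
        return model_output
    if prediction_type == "sample":
        # the model predicts x_0; invert x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * eps
        return (sample - alpha_prod_t**0.5 * model_output) / beta_prod_t**0.5
    if prediction_type == "v_prediction":
        # the model predicts v = sqrt(alpha_prod_t) * eps - sqrt(1 - alpha_prod_t) * x_0
        return alpha_prod_t**0.5 * model_output + beta_prod_t**0.5 * sample
    raise ValueError(f"unknown prediction_type: {prediction_type}")
```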
@@ -23,7 +23,7 @@ import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from diffusers.utils import BaseOutput
+from diffusers.utils import BaseOutput, deprecate


@dataclass
@@ -96,15 +96,17 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
-            option to clip predicted sample between -1 and 1 for numerical stability.
-        set_alpha_to_one (`bool`, default `True`):
+            option to clip predicted sample for numerical stability.
+        clip_sample_range (`float`, default `1.0`):
+            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+        set_alpha_to_zero (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
-            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
-            otherwise it uses the value of alpha at step 0.
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `0`,
+            otherwise it uses the value of alpha at step `num_train_timesteps - 1`.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
-            stable diffusion.
+            `set_alpha_to_zero=False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha
+            product.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
@@ -122,10 +124,18 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
-        set_alpha_to_one: bool = True,
+        set_alpha_to_zero: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
+        clip_sample_range: float = 1.0,
+        **kwargs,
    ):
+        if kwargs.get("set_alpha_to_one", None) is not None:
+            deprecation_message = (
+                "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead."
+            )
+            deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False)
+            set_alpha_to_zero = kwargs["set_alpha_to_one"]
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
@@ -144,11 +154,12 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

-        # At every step in ddim, we are looking into the previous alphas_cumprod
-        # For the final step, there is no previous alphas_cumprod because we are already at 0
-        # `set_alpha_to_one` decides whether we set this parameter simply to one or
-        # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        # At every step in inverted ddim, we are looking into the next alphas_cumprod
+        # For the final step, there is no next alphas_cumprod, and the index is out of bounds
+        # `set_alpha_to_zero` decides whether we set this parameter simply to zero
+        # in this case, self.step() just output the predicted noise
+        # or whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0
@@ -157,6 +168,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

+    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -205,23 +217,52 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
-        e_t = model_output
-        x = sample
+        # 1. get previous step value (=t+1)
        prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
-        a_t = self.alphas_cumprod[timestep - 1]
-        a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod
-        pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt()
-        dir_xt = (1.0 - a_prev).sqrt() * e_t
-        prev_sample = a_prev.sqrt() * pred_x0 + dir_xt
+        # 2. compute alphas, betas
+        # change original implementation to exactly match noise levels for analogous forward process
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = (
+            self.alphas_cumprod[prev_timestep]
+            if prev_timestep < self.config.num_train_timesteps
+            else self.final_alpha_cumprod
+        )
+        beta_prod_t = 1 - alpha_prod_t
+
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            pred_epsilon = model_output
+        elif self.config.prediction_type == "sample":
+            pred_original_sample = model_output
+            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+        elif self.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                " `v_prediction`"
+            )
+
+        # 4. Clip or threshold "predicted x_0"
+        if self.config.clip_sample:
+            pred_original_sample = pred_original_sample.clamp(
+                -self.config.clip_sample_range, self.config.clip_sample_range
+            )
+
+        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+
+        # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if not return_dict:
-            return (prev_sample, pred_x0)
+            return (prev_sample, pred_original_sample)

-        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
+        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def __len__(self):
        return self.config.num_train_timesteps
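The rewritten `step()` above now mirrors `DDIMScheduler.step()`: it first derives `pred_original_sample` and `pred_epsilon` from the model output according to `prediction_type`, then recombines them at the next, noisier timestep. For an epsilon-type model this is, in the notation of formula (12) from the DDIM paper with the stochastic term dropped (a sketch of what the code computes, where alpha-bar denotes `alphas_cumprod` and Delta is the timestep spacing):

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\hat{\epsilon}_t}{\sqrt{\bar{\alpha}_t}},
\qquad
x_{t+\Delta} = \sqrt{\bar{\alpha}_{t+\Delta}}\,\hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t+\Delta}}\,\hat{\epsilon}_t .
```

This is also why `set_alpha_to_zero` replaces `set_alpha_to_one`: inversion runs towards higher noise, so the out-of-range index at the last step corresponds to a fully noised sample with alpha-bar equal to zero, and the update then reduces to returning the predicted noise.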
@@ -347,7 +347,6 @@ class InversionPipelineSlowTests(unittest.TestCase):
        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
-        pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

        caption = "a photography of a cat with flowers"
@@ -366,6 +365,28 @@ class InversionPipelineSlowTests(unittest.TestCase):
        assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2

+    def test_stable_diffusion_2_pix2pix_inversion(self):
+        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
+        )
+        pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+
+        caption = "a photography of a cat with flowers"
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.manual_seed(0)
+        output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
+        inv_latents = output[0]
+
+        image_slice = inv_latents[0, -3:, -3:, -1].flatten()
+
+        assert inv_latents.shape == (1, 4, 64, 64)
+        expected_slice = np.array([0.7515, -0.2397, 0.4922, -0.9736, -0.7031, 0.4846, -1.0781, 1.1309, -0.6973])
+
+        assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2
    def test_stable_diffusion_pix2pix_full(self):
        # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
        expected_image = load_numpy(
@@ -375,7 +396,6 @@ class InversionPipelineSlowTests(unittest.TestCase):
        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
-        pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

        caption = "a photography of a cat with flowers"
@@ -407,3 +427,44 @@ class InversionPipelineSlowTests(unittest.TestCase):
        max_diff = np.abs(expected_image - image).mean()
        assert max_diff < 0.05

+    def test_stable_diffusion_2_pix2pix_full(self):
+        # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog_2.png
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog_2.npy"
+        )
+
+        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
+        )
+        pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+
+        caption = "a photography of a cat with flowers"
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.manual_seed(0)
+        output = pipe.invert(caption, image=self.raw_image, generator=generator)
+        inv_latents = output[0]
+
+        source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
+        target_prompts = 4 * ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
+
+        source_embeds = pipe.get_embeds(source_prompts)
+        target_embeds = pipe.get_embeds(target_prompts)
+
+        image = pipe(
+            caption,
+            source_embeds=source_embeds,
+            target_embeds=target_embeds,
+            num_inference_steps=125,
+            cross_attention_guidance_amount=0.015,
+            generator=generator,
+            latents=inv_latents,
+            negative_prompt=caption,
+            output_type="np",
+        ).images
+
+        max_diff = np.abs(expected_image - image).mean()
+        assert max_diff < 0.05