Unverified Commit cc59b056 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples (#334)



* add outputs for models

* add for pipelines

* finish schedulers

* better naming

* adapt tests as well

* replace dict access with . access

* make schedulers works

* finish

* correct readme

* make backward compatible

* up

* small fix

* finish

* more fixes

* more fixes

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/models/vae.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Adapt model outputs

* Apply more suggestions

* finish examples

* correct
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent daddd98b
...@@ -10,6 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -10,6 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
...@@ -57,6 +58,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -57,6 +58,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True,
): ):
if isinstance(prompt, str): if isinstance(prompt, str):
...@@ -83,7 +85,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -83,7 +85,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_image = preprocess(init_image) init_image = preprocess(init_image)
# encode the init image into latents and scale the latents # encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image.to(self.device)).sample(generator=generator) init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents init_latents = 0.18215 * init_latents
# expand init_latents for batch_size # expand init_latents for batch_size
...@@ -158,7 +161,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -158,7 +161,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
t = t.to(self.unet.dtype) t = t.to(self.unet.dtype)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -167,13 +170,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -167,13 +170,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler): if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"] latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
else: else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents.to(self.vae.dtype)) image = self.vae.decode(latents.to(self.vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
...@@ -185,4 +188,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -185,4 +188,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept} if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
...@@ -11,6 +11,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -11,6 +11,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
...@@ -72,6 +73,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -72,6 +73,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True,
): ):
if isinstance(prompt, str): if isinstance(prompt, str):
...@@ -98,7 +100,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -98,7 +100,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
init_image = preprocess_image(init_image).to(self.device) init_image = preprocess_image(init_image).to(self.device)
# encode the init image into latents and scale the latents # encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image).sample(generator=generator) init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size # Expand init_latents for batch_size
...@@ -166,7 +170,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -166,7 +170,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -174,7 +178,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -174,7 +178,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking # masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
...@@ -182,7 +186,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -182,7 +186,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# scale and decode the image latents with vae # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents) image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
...@@ -194,4 +198,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -194,4 +198,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept} if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
#!/usr/bin/env python3 #!/usr/bin/env python3
import warnings import warnings
from typing import Optional from typing import Optional, Tuple, Union
import torch import torch
from ...models import UNet2DModel from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import KarrasVeScheduler from ...schedulers import KarrasVeScheduler
...@@ -35,8 +35,9 @@ class KarrasVePipeline(DiffusionPipeline): ...@@ -35,8 +35,9 @@ class KarrasVePipeline(DiffusionPipeline):
num_inference_steps: int = 50, num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs, **kwargs,
): ) -> Union[Tuple, ImagePipelineOutput]:
if "torch_device" in kwargs: if "torch_device" in kwargs:
device = kwargs.pop("torch_device") device = kwargs.pop("torch_device")
warnings.warn( warnings.warn(
...@@ -71,7 +72,7 @@ class KarrasVePipeline(DiffusionPipeline): ...@@ -71,7 +72,7 @@ class KarrasVePipeline(DiffusionPipeline):
# 3. Predict the noise residual given the noise magnitude `sigma_hat` # 3. Predict the noise residual given the noise magnitude `sigma_hat`
# The model inputs and output are adjusted by following eq. (213) in [1]. # The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"] model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
# 4. Evaluate dx/dt at sigma_hat # 4. Evaluate dx/dt at sigma_hat
# 5. Take Euler step from sigma to sigma_prev # 5. Take Euler step from sigma to sigma_prev
...@@ -80,20 +81,23 @@ class KarrasVePipeline(DiffusionPipeline): ...@@ -80,20 +81,23 @@ class KarrasVePipeline(DiffusionPipeline):
if sigma_prev != 0: if sigma_prev != 0:
# 6. Apply 2nd order correction # 6. Apply 2nd order correction
# The model inputs and output are adjusted by following eq. (213) in [1]. # The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"] model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
step_output = self.scheduler.step_correct( step_output = self.scheduler.step_correct(
model_output, model_output,
sigma_hat, sigma_hat,
sigma_prev, sigma_prev,
sample_hat, sample_hat,
step_output["prev_sample"], step_output.prev_sample,
step_output["derivative"], step_output["derivative"],
) )
sample = step_output["prev_sample"] sample = step_output.prev_sample
sample = (sample / 2 + 0.5).clamp(0, 1) sample = (sample / 2 + 0.5).clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy() image = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil": if output_type == "pil":
sample = self.numpy_to_pil(sample) image = self.numpy_to_pil(sample)
return {"sample": sample} if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
# and https://github.com/hojonathanho/diffusion # and https://github.com/hojonathanho/diffusion
import math import math
from typing import Optional, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
...@@ -116,7 +116,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -116,7 +116,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta: float = 0.0, eta: float = 0.0,
use_clipped_model_output: bool = False, use_clipped_model_output: bool = False,
generator=None, generator=None,
): return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.num_inference_steps is None: if self.num_inference_steps is None:
raise ValueError( raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
...@@ -174,7 +176,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -174,7 +176,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance prev_sample = prev_sample + variance
return {"prev_sample": prev_sample} if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise( def add_noise(
self, self,
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
from typing import Optional, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
...@@ -135,7 +135,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -135,7 +135,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True, predict_epsilon=True,
generator=None, generator=None,
): return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
t = timestep t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
...@@ -177,7 +179,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -177,7 +179,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance pred_prev_sample = pred_prev_sample + variance
return {"prev_sample": pred_prev_sample} if not return_dict:
return (pred_prev_sample,)
return SchedulerOutput(prev_sample=pred_prev_sample)
def add_noise( def add_noise(
self, self,
......
...@@ -13,15 +13,34 @@ ...@@ -13,15 +13,34 @@
# limitations under the License. # limitations under the License.
from typing import Union from dataclasses import dataclass
from typing import Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
@dataclass
class KarrasVeOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Derivative of the predicted original image sample (x_0).
    """

    # Denoised sample for the previous timestep; feed back into the model.
    prev_sample: torch.FloatTensor
    # dx/dt term produced by the Karras et al. Euler step; consumed by step_correct.
    derivative: torch.FloatTensor
class KarrasVeScheduler(SchedulerMixin, ConfigMixin): class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
""" """
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
...@@ -102,12 +121,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -102,12 +121,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sigma_hat: float, sigma_hat: float,
sigma_prev: float, sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray], sample_hat: Union[torch.FloatTensor, np.ndarray],
): return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
pred_original_sample = sample_hat + sigma_hat * model_output pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
return {"prev_sample": sample_prev, "derivative": derivative} if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
def step_correct( def step_correct(
self, self,
...@@ -117,11 +141,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -117,11 +141,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat: Union[torch.FloatTensor, np.ndarray], sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_prev: Union[torch.FloatTensor, np.ndarray], sample_prev: Union[torch.FloatTensor, np.ndarray],
derivative: Union[torch.FloatTensor, np.ndarray], derivative: Union[torch.FloatTensor, np.ndarray],
): return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
pred_original_sample = sample_prev + sigma_prev * model_output pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
return {"prev_sample": sample_prev, "derivative": derivative_corr}
if not return_dict:
return (sample_prev, derivative_corr)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative_corr)
def add_noise(self, original_samples, noise, timesteps): def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError() raise NotImplementedError()
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Union from typing import Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
from scipy import integrate from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin, SchedulerOutput
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...@@ -100,7 +100,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -100,7 +100,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4, order: int = 4,
): return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
sigma = self.sigmas[timestep] sigma = self.sigmas[timestep]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
...@@ -121,7 +122,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -121,7 +122,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives)) coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
) )
return {"prev_sample": prev_sample} if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(self, original_samples, noise, timesteps): def add_noise(self, original_samples, noise, timesteps):
sigmas = self.match_shape(self.sigmas[timesteps], noise) sigmas = self.match_shape(self.sigmas[timesteps], noise)
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
from typing import Union from typing import Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
...@@ -133,18 +133,21 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -133,18 +133,21 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
): return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample) return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
else: else:
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample) return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
def step_prk( def step_prk(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
): return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
""" """
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation. solution to the differential equation.
...@@ -176,14 +179,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -176,14 +179,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
self.counter += 1 self.counter += 1
return {"prev_sample": prev_sample} if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def step_plms( def step_plms(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
): return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
""" """
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution. times to approximate the solution.
...@@ -226,7 +233,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -226,7 +233,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1 self.counter += 1
return {"prev_sample": prev_sample} if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
......
...@@ -15,13 +15,32 @@ ...@@ -15,13 +15,32 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
import warnings import warnings
from typing import Optional, Union from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the ScoreSdeVeScheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
    """

    # Sample for the previous timestep (includes the diffusion noise term).
    prev_sample: torch.FloatTensor
    # Same sample before the noise term is added (the drift-only mean).
    prev_sample_mean: torch.FloatTensor
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
...@@ -117,8 +136,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -117,8 +136,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs, **kwargs,
): ) -> Union[SdeVeOutput, Tuple]:
""" """
Predict the sample at the previous timestep by reversing the SDE. Predict the sample at the previous timestep by reversing the SDE.
""" """
...@@ -150,15 +170,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -150,15 +170,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# TODO is the variable diffusion the correct scaling term for the noise? # TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean} if not return_dict:
return (prev_sample, prev_sample_mean)
return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
def step_correct( def step_correct(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs, **kwargs,
): ) -> Union[SchedulerOutput, Tuple]:
""" """
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep. after making the prediction for the previous timestep.
...@@ -186,7 +210,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -186,7 +210,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
prev_sample_mean = sample + step_size[:, None, None, None] * model_output prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
return {"prev_sample": prev_sample} if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def __len__(self): def __len__(self):
return self.config.num_train_timesteps return self.config.num_train_timesteps
...@@ -11,15 +11,32 @@ ...@@ -11,15 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass
from typing import Union from typing import Union
import numpy as np import numpy as np
import torch import torch
from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json" SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@dataclass
class SchedulerOutput(BaseOutput):
    """
    Base class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    # Denoised sample for the previous timestep; feed back into the model.
    prev_sample: torch.FloatTensor
class SchedulerMixin: class SchedulerMixin:
config_name = SCHEDULER_CONFIG_NAME config_name = SCHEDULER_CONFIG_NAME
......
...@@ -33,6 +33,7 @@ from .import_utils import ( ...@@ -33,6 +33,7 @@ from .import_utils import (
requires_backends, requires_backends,
) )
from .logging import get_logger from .logging import get_logger
from .outputs import BaseOutput
logger = get_logger(__name__) logger = get_logger(__name__)
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic utilities
"""
import warnings
from collections import OrderedDict
from dataclasses import fields
from typing import Any, Tuple
import numpy as np
from .import_utils import is_torch_available
def is_tensor(x):
    """
    Return `True` when `x` is a `torch.Tensor` (if torch is installed) or a `np.ndarray`.
    """
    if is_torch_available():
        import torch

        # torch is present: accept either framework's tensor type in a single check.
        return isinstance(x, (torch.Tensor, np.ndarray))

    # torch is unavailable, so only numpy arrays count as tensors.
    return isinstance(x, np.ndarray)
class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    <Tip warning={true}>

    You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
    before.

    </Tip>
    """

    def __post_init__(self):
        """Mirror the dataclass fields into the underlying `OrderedDict`, skipping `None` values."""
        class_fields = fields(self)

        # Safety and consistency checks
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        # Special case: a single non-tensor first field that is a dict or a (key, value) iterator is
        # expanded into the output's keys instead of being stored as-is.
        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for element in iterator:
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        # Not a (str, value) pair: stop expanding and leave remaining elements alone.
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            # General case: copy every non-None field into the dict part of this object.
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        """String keys index the dict part (with a deprecation shim for `"sample"`); int/slice index `to_tuple()`."""
        if isinstance(k, str):
            inner_dict = {k: v for (k, v) in self.items()}
            if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
                # Fix: the deprecated key is 'sample' (checked above), not 'samples' as the message previously said.
                warnings.warn(
                    "The keyword 'sample' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
                    " `'images'` instead.",
                    DeprecationWarning,
                )
                return inner_dict["images"]
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        # Keep attribute and dict views in sync for keys that already exist.
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import inspect import inspect
import tempfile import tempfile
from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -39,12 +40,12 @@ class ModelTesterMixin: ...@@ -39,12 +40,12 @@ class ModelTesterMixin:
with torch.no_grad(): with torch.no_grad():
image = model(**inputs_dict) image = model(**inputs_dict)
if isinstance(image, dict): if isinstance(image, dict):
image = image["sample"] image = image.sample
new_image = new_model(**inputs_dict) new_image = new_model(**inputs_dict)
if isinstance(new_image, dict): if isinstance(new_image, dict):
new_image = new_image["sample"] new_image = new_image.sample
max_diff = (image - new_image).abs().sum().item() max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
...@@ -57,11 +58,11 @@ class ModelTesterMixin: ...@@ -57,11 +58,11 @@ class ModelTesterMixin:
with torch.no_grad(): with torch.no_grad():
first = model(**inputs_dict) first = model(**inputs_dict)
if isinstance(first, dict): if isinstance(first, dict):
first = first["sample"] first = first.sample
second = model(**inputs_dict) second = model(**inputs_dict)
if isinstance(second, dict): if isinstance(second, dict):
second = second["sample"] second = second.sample
out_1 = first.cpu().numpy() out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy() out_2 = second.cpu().numpy()
...@@ -80,7 +81,7 @@ class ModelTesterMixin: ...@@ -80,7 +81,7 @@ class ModelTesterMixin:
output = model(**inputs_dict) output = model(**inputs_dict)
if isinstance(output, dict): if isinstance(output, dict):
output = output["sample"] output = output.sample
self.assertIsNotNone(output) self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape expected_shape = inputs_dict["sample"].shape
...@@ -122,12 +123,12 @@ class ModelTesterMixin: ...@@ -122,12 +123,12 @@ class ModelTesterMixin:
output_1 = model(**inputs_dict) output_1 = model(**inputs_dict)
if isinstance(output_1, dict): if isinstance(output_1, dict):
output_1 = output_1["sample"] output_1 = output_1.sample
output_2 = new_model(**inputs_dict) output_2 = new_model(**inputs_dict)
if isinstance(output_2, dict): if isinstance(output_2, dict):
output_2 = output_2["sample"] output_2 = output_2.sample
self.assertEqual(output_1.shape, output_2.shape) self.assertEqual(output_1.shape, output_2.shape)
...@@ -140,7 +141,7 @@ class ModelTesterMixin: ...@@ -140,7 +141,7 @@ class ModelTesterMixin:
output = model(**inputs_dict) output = model(**inputs_dict)
if isinstance(output, dict): if isinstance(output, dict):
output = output["sample"] output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise) loss = torch.nn.functional.mse_loss(output, noise)
...@@ -157,9 +158,47 @@ class ModelTesterMixin: ...@@ -157,9 +158,47 @@ class ModelTesterMixin:
output = model(**inputs_dict) output = model(**inputs_dict)
if isinstance(output, dict): if isinstance(output, dict):
output = output["sample"] output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise) loss = torch.nn.functional.mse_loss(output, noise)
loss.backward() loss.backward()
ema_model.step(model) ema_model.step(model)
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
...@@ -77,7 +77,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -77,7 +77,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# time_step = torch.tensor([10]) # time_step = torch.tensor([10])
# #
# with torch.no_grad(): # with torch.no_grad():
# output = model(noise, time_step)["sample"] # output = model(noise, time_step).sample
# #
# output_slice = output[0, -1, -3:, -3:].flatten() # output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off # fmt: off
...@@ -129,7 +129,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -129,7 +129,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device) model.to(torch_device)
image = model(**self.dummy_input)["sample"] image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
...@@ -147,7 +147,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -147,7 +147,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device) time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step)["sample"] output = model(noise, time_step).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu() output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off # fmt: off
...@@ -258,7 +258,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -258,7 +258,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step)["sample"] output = model(noise, time_step).sample
output_slice = output[0, -3:, -3:, -1].flatten().cpu() output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off # fmt: off
...@@ -283,7 +283,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -283,7 +283,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step)["sample"] output = model(noise, time_step).sample
output_slice = output[0, -3:, -3:, -1].flatten().cpu() output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off # fmt: off
......
...@@ -87,7 +87,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): ...@@ -87,7 +87,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device) image = image.to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(image, sample_posterior=True) output = model(image, sample_posterior=True).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu() output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off # fmt: off
......
...@@ -85,7 +85,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -85,7 +85,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device) image = image.to(torch_device)
with torch.no_grad(): with torch.no_grad():
output = model(image) output = model(image).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu() output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off # fmt: off
......
This diff is collapsed.
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import tempfile import tempfile
import unittest import unittest
from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -85,8 +86,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -85,8 +86,8 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -114,9 +115,9 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -114,9 +115,9 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0) torch.manual_seed(0)
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
torch.manual_seed(0) torch.manual_seed(0)
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -143,9 +144,9 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -143,9 +144,9 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0) torch.manual_seed(0)
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
torch.manual_seed(0) torch.manual_seed(0)
new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -166,8 +167,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -166,8 +167,8 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"] output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
...@@ -195,11 +196,64 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -195,11 +196,64 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_dict = scheduler.step(residual, 0, sample, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs)
recursive_check(outputs_tuple, outputs_dict)
class DDPMSchedulerTest(SchedulerCommonTest): class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMScheduler,) scheduler_classes = (DDPMScheduler,)
...@@ -270,7 +324,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -270,7 +324,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual = model(sample, t) residual = model(sample, t)
# 2. predict previous mean of sample x_t-1 # 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"] pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
# if t > 0: # if t > 0:
# noise = self.dummy_sample_deter # noise = self.dummy_sample_deter
...@@ -356,7 +410,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -356,7 +410,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for t in scheduler.timesteps: for t in scheduler.timesteps:
residual = model(sample, t) residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta)["prev_sample"] sample = scheduler.step(residual, t, sample, eta).prev_sample
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
...@@ -401,13 +455,13 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -401,13 +455,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -438,13 +492,13 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -438,13 +492,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
# copy over dummy past residual (must be after setting timesteps) # copy over dummy past residual (must be after setting timesteps)
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -476,12 +530,12 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -476,12 +530,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.ets = dummy_past_residuals[:] scheduler.ets = dummy_past_residuals[:]
scheduler_pt.ets = dummy_past_residuals_pt[:] scheduler_pt.ets = dummy_past_residuals_pt[:]
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"] output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"] output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
...@@ -535,14 +589,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -535,14 +589,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:] scheduler.ets = dummy_past_residuals[:]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"] output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"] output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"] output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"] output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
...@@ -573,7 +627,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -573,7 +627,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"] scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
...@@ -587,11 +641,11 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -587,11 +641,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for i, t in enumerate(scheduler.prk_timesteps): for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t) residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample)["prev_sample"] sample = scheduler.step_prk(residual, i, sample).prev_sample
for i, t in enumerate(scheduler.plms_timesteps): for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t) residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample)["prev_sample"] sample = scheduler.step_plms(residual, i, sample).prev_sample
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
...@@ -664,13 +718,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -664,13 +718,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"] output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
...@@ -689,13 +743,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -689,13 +743,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"] output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
...@@ -732,13 +786,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -732,13 +786,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
for _ in range(scheduler.correct_steps): for _ in range(scheduler.correct_steps):
with torch.no_grad(): with torch.no_grad():
model_output = model(sample, sigma_t) model_output = model(sample, sigma_t)
sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"] sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample
with torch.no_grad(): with torch.no_grad():
model_output = model(sample, sigma_t) model_output = model(sample, sigma_t)
output = scheduler.step_pred(model_output, t, sample, **kwargs) output = scheduler.step_pred(model_output, t, sample, **kwargs)
sample, _ = output["prev_sample"], output["prev_sample_mean"] sample, _ = output.prev_sample, output.prev_sample_mean
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
...@@ -763,8 +817,8 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -763,8 +817,8 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step_pred(residual, 0, sample, **kwargs)["prev_sample"] output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"] output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
...@@ -66,7 +66,7 @@ class TrainingTests(unittest.TestCase): ...@@ -66,7 +66,7 @@ class TrainingTests(unittest.TestCase):
for i in range(4): for i in range(4):
optimizer.zero_grad() optimizer.zero_grad()
ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i]) ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"] ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i]).sample
loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i]) loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
loss.backward() loss.backward()
optimizer.step() optimizer.step()
...@@ -78,7 +78,7 @@ class TrainingTests(unittest.TestCase): ...@@ -78,7 +78,7 @@ class TrainingTests(unittest.TestCase):
for i in range(4): for i in range(4):
optimizer.zero_grad() optimizer.zero_grad()
ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i]) ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"] ddim_noise_pred = model(ddim_noisy_images, timesteps[i]).sample
loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i]) loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
loss.backward() loss.backward()
optimizer.step() optimizer.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment