Unverified Commit cc59b056 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples (#334)



* add outputs for models

* add for pipelines

* finish schedulers

* better naming

* adapt tests as well

* replace dict access with . access

* make schedulers works

* finish

* correct readme

* make  bcp compatible

* up

* small fix

* finish

* more fixes

* more fixes

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/models/vae.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Adapt model outputs

* Apply more suggestions

* finish examples

* correct
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent daddd98b
......@@ -10,6 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -57,6 +58,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
if isinstance(prompt, str):
......@@ -83,7 +85,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_image = preprocess(init_image)
# encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image.to(self.device)).sample(generator=generator)
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# expand init_latents for batch_size
......@@ -158,7 +161,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
t = t.to(self.unet.dtype)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -167,13 +170,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents.to(self.vae.dtype))
image = self.vae.decode(latents.to(self.vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
......@@ -185,4 +188,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
......@@ -11,6 +11,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -72,6 +73,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
if isinstance(prompt, str):
......@@ -98,7 +100,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
init_image = preprocess_image(init_image).to(self.device)
# encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image).sample(generator=generator)
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size
......@@ -166,7 +170,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -174,7 +178,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
......@@ -182,7 +186,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
......@@ -194,4 +198,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
#!/usr/bin/env python3
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import KarrasVeScheduler
......@@ -35,8 +35,9 @@ class KarrasVePipeline(DiffusionPipeline):
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[Tuple, ImagePipelineOutput]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
......@@ -71,7 +72,7 @@ class KarrasVePipeline(DiffusionPipeline):
# 3. Predict the noise residual given the noise magnitude `sigma_hat`
# The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
# 4. Evaluate dx/dt at sigma_hat
# 5. Take Euler step from sigma to sigma_prev
......@@ -80,20 +81,23 @@ class KarrasVePipeline(DiffusionPipeline):
if sigma_prev != 0:
# 6. Apply 2nd order correction
# The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
step_output = self.scheduler.step_correct(
model_output,
sigma_hat,
sigma_prev,
sample_hat,
step_output["prev_sample"],
step_output.prev_sample,
step_output["derivative"],
)
sample = step_output["prev_sample"]
sample = step_output.prev_sample
sample = (sample / 2 + 0.5).clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
image = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
image = self.numpy_to_pil(sample)
return {"sample": sample}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
......@@ -16,13 +16,13 @@
# and https://github.com/hojonathanho/diffusion
import math
from typing import Optional, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -116,7 +116,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
......@@ -174,7 +176,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
......
......@@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Optional, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -135,7 +135,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True,
generator=None,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
......@@ -177,7 +179,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
return {"prev_sample": pred_prev_sample}
if not return_dict:
return (pred_prev_sample,)
return SchedulerOutput(prev_sample=pred_prev_sample)
def add_noise(
self,
......
......@@ -13,15 +13,34 @@
# limitations under the License.
from typing import Union
from dataclasses import dataclass
from typing import Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class KarrasVeOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivate of predicted original image sample (x_0).
"""
prev_sample: torch.FloatTensor
derivative: torch.FloatTensor
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
......@@ -102,12 +121,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
return {"prev_sample": sample_prev, "derivative": derivative}
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
def step_correct(
self,
......@@ -117,11 +141,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_prev: Union[torch.FloatTensor, np.ndarray],
derivative: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
return {"prev_sample": sample_prev, "derivative": derivative_corr}
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Tuple, Union
import numpy as np
import torch
......@@ -20,7 +20,7 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
......@@ -100,7 +100,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
sigma = self.sigmas[timestep]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
......@@ -121,7 +122,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
)
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(self, original_samples, noise, timesteps):
sigmas = self.match_shape(self.sigmas[timesteps], noise)
......
......@@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Union
from typing import Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -133,18 +133,21 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
else:
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
def step_prk(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
......@@ -176,14 +179,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
self.counter += 1
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def step_plms(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
......@@ -226,7 +233,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
......
......@@ -15,13 +15,32 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
import warnings
from typing import Optional, Union
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@dataclass
class SdeVeOutput(BaseOutput):
"""
Output class for the ScoreSdeVeScheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
"""
prev_sample: torch.FloatTensor
prev_sample_mean: torch.FloatTensor
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
......@@ -117,8 +136,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs,
):
) -> Union[SdeVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE.
"""
......@@ -150,15 +170,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean}
if not return_dict:
return (prev_sample, prev_sample_mean)
return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs,
):
) -> Union[SchedulerOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
......@@ -186,7 +210,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
......@@ -11,15 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Union
import numpy as np
import torch
from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@dataclass
class SchedulerOutput(BaseOutput):
"""
Base class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class SchedulerMixin:
config_name = SCHEDULER_CONFIG_NAME
......
......@@ -33,6 +33,7 @@ from .import_utils import (
requires_backends,
)
from .logging import get_logger
from .outputs import BaseOutput
logger = get_logger(__name__)
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic utilities
"""
import warnings
from collections import OrderedDict
from dataclasses import fields
from typing import Any, Tuple
import numpy as np
from .import_utils import is_torch_available
def is_tensor(x):
"""
Tests if `x` is a `torch.Tensor` or `np.ndarray`.
"""
if is_torch_available():
import torch
if isinstance(x, torch.Tensor):
return True
return isinstance(x, np.ndarray)
class BaseOutput(OrderedDict):
"""
Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
python dictionary.
<Tip warning={true}>
You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
before.
</Tip>
"""
def __post_init__(self):
class_fields = fields(self)
# Safety and consistency checks
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
if other_fields_are_none and not is_tensor(first_field):
if isinstance(first_field, dict):
iterator = first_field.items()
first_field_iterator = True
else:
try:
iterator = iter(first_field)
first_field_iterator = True
except TypeError:
first_field_iterator = False
# if we provided an iterator as first field and the iterator is a (key, value) iterator
# set the associated fields
if first_field_iterator:
for element in iterator:
if (
not isinstance(element, (list, tuple))
or not len(element) == 2
or not isinstance(element[0], str)
):
break
setattr(self, element[0], element[1])
if element[1] is not None:
self[element[0]] = element[1]
elif first_field is not None:
self[class_fields[0].name] = first_field
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = {k: v for (k, v) in self.items()}
if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
warnings.warn(
"The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
" `'images'` instead.",
DeprecationWarning,
)
return inner_dict["images"]
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
# Don't call self.__setitem__ to avoid recursion errors
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
# Will raise a KeyException if needed
super().__setitem__(key, value)
# Don't call self.__setattr__ to avoid recursion errors
super().__setattr__(key, value)
def to_tuple(self) -> Tuple[Any]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
return tuple(self[k] for k in self.keys())
......@@ -15,6 +15,7 @@
import inspect
import tempfile
from typing import Dict, List, Tuple
import numpy as np
import torch
......@@ -39,12 +40,12 @@ class ModelTesterMixin:
with torch.no_grad():
image = model(**inputs_dict)
if isinstance(image, dict):
image = image["sample"]
image = image.sample
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image["sample"]
new_image = new_image.sample
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
......@@ -57,11 +58,11 @@ class ModelTesterMixin:
with torch.no_grad():
first = model(**inputs_dict)
if isinstance(first, dict):
first = first["sample"]
first = first.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second["sample"]
second = second.sample
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
......@@ -80,7 +81,7 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
......@@ -122,12 +123,12 @@ class ModelTesterMixin:
output_1 = model(**inputs_dict)
if isinstance(output_1, dict):
output_1 = output_1["sample"]
output_1 = output_1.sample
output_2 = new_model(**inputs_dict)
if isinstance(output_2, dict):
output_2 = output_2["sample"]
output_2 = output_2.sample
self.assertEqual(output_1.shape, output_2.shape)
......@@ -140,7 +141,7 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
......@@ -157,9 +158,47 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model)
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
......@@ -77,7 +77,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# time_step = torch.tensor([10])
#
# with torch.no_grad():
# output = model(noise, time_step)["sample"]
# output = model(noise, time_step).sample
#
# output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
......@@ -129,7 +129,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)["sample"]
image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None"
......@@ -147,7 +147,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
......@@ -258,7 +258,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
......@@ -283,7 +283,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
......
......@@ -87,7 +87,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image, sample_posterior=True)
output = model(image, sample_posterior=True).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
......
......@@ -85,7 +85,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image)
output = model(image).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
......
......@@ -67,12 +67,12 @@ def test_progress_bar(capsys):
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler).to(torch_device)
ddpm(output_type="numpy")["sample"]
ddpm(output_type="numpy").images
captured = capsys.readouterr()
assert "10/10" in captured.err, "Progress bar has to be displayed"
ddpm.set_progress_bar_config(disable=True)
ddpm(output_type="numpy")["sample"]
ddpm(output_type="numpy").images
captured = capsys.readouterr()
assert captured.err == "", "Progress bar should be disabled"
......@@ -196,15 +196,20 @@ class PipelineFastTests(unittest.TestCase):
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
generator = torch.manual_seed(0)
image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array(
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_pndm_cifar10(self):
unet = self.dummy_uncond_unet
......@@ -213,14 +218,20 @@ class PipelineFastTests(unittest.TestCase):
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
pndm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images
generator = torch.manual_seed(0)
image = pndm(generator=generator, num_inference_steps=20, output_type="numpy")["sample"]
image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_ldm_text2img(self):
unet = self.dummy_cond_unet
......@@ -239,11 +250,23 @@ class PipelineFastTests(unittest.TestCase):
"sample"
]
generator = torch.manual_seed(0)
image_from_tuple = ldm(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="numpy",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
......@@ -274,16 +297,28 @@ class PipelineFastTests(unittest.TestCase):
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output.images
image = output["sample"]
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_pndm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
......@@ -310,13 +345,25 @@ class PipelineFastTests(unittest.TestCase):
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
......@@ -343,13 +390,25 @@ class PipelineFastTests(unittest.TestCase):
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_score_sde_ve_pipeline(self):
unet = self.dummy_uncond_unet
......@@ -360,14 +419,19 @@ class PipelineFastTests(unittest.TestCase):
sde_ve.set_progress_bar_config(disable=None)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=2, output_type="numpy")["sample"]
image = sde_ve(num_inference_steps=2, output_type="numpy").images
torch.manual_seed(0)
image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_ldm_uncond(self):
unet = self.dummy_uncond_unet
......@@ -379,13 +443,18 @@ class PipelineFastTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
generator = torch.manual_seed(0)
image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_karras_ve_pipeline(self):
unet = self.dummy_uncond_unet
......@@ -396,12 +465,18 @@ class PipelineFastTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy")["sample"]
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
......@@ -437,13 +512,26 @@ class PipelineFastTests(unittest.TestCase):
init_image=init_image,
)
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
......@@ -479,14 +567,27 @@ class PipelineFastTests(unittest.TestCase):
output_type="np",
init_image=init_image,
)
image = output.images
image = output["sample"]
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
return_dict=False,
)
image_from_tuple = output[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
......@@ -525,13 +626,27 @@ class PipelineFastTests(unittest.TestCase):
mask_image=mask_image,
)
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
class PipelineTesterMixin(unittest.TestCase):
......@@ -565,9 +680,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image = ddpm(generator=generator, output_type="numpy").images
generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
new_image = new_ddpm(generator=generator, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
......@@ -586,9 +701,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image = ddpm(generator=generator, output_type="numpy").images
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
......@@ -610,9 +725,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
......@@ -625,17 +740,17 @@ class PipelineTesterMixin(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
images = pipe(generator=generator, output_type="numpy")["sample"]
images = pipe(generator=generator, output_type="numpy").images
assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray)
images = pipe(generator=generator, output_type="pil")["sample"]
images = pipe(generator=generator, output_type="pil").images
assert isinstance(images, list)
assert len(images) == 1
assert isinstance(images[0], PIL.Image.Image)
# use PIL by default
images = pipe(generator=generator)["sample"]
images = pipe(generator=generator).images
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)
......@@ -652,7 +767,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image = ddpm(generator=generator, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
......@@ -672,7 +787,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image = ddpm(generator=generator, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
......@@ -692,7 +807,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
image = ddim(generator=generator, eta=0.0, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
......@@ -711,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):
pndm.to(torch_device)
pndm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pndm(generator=generator, output_type="numpy")["sample"]
image = pndm(generator=generator, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
......@@ -745,7 +860,7 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
......@@ -768,7 +883,7 @@ class PipelineTesterMixin(unittest.TestCase):
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
)
image = output["sample"]
image = output.images
image_slice = image[0, -3:, -3:, -1]
......@@ -797,7 +912,7 @@ class PipelineTesterMixin(unittest.TestCase):
with torch.autocast("cuda"):
output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
image = output["sample"]
image = output.images
image_slice = image[0, -3:, -3:, -1]
......@@ -817,7 +932,7 @@ class PipelineTesterMixin(unittest.TestCase):
sde_ve.set_progress_bar_config(disable=None)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
image = sde_ve(num_inference_steps=300, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
......@@ -833,7 +948,7 @@ class PipelineTesterMixin(unittest.TestCase):
ldm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
......@@ -857,10 +972,10 @@ class PipelineTesterMixin(unittest.TestCase):
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
ddpm_image = ddpm(generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_image - ddim_image).max() < 1e-1
......@@ -882,7 +997,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
......@@ -903,7 +1018,7 @@ class PipelineTesterMixin(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
......@@ -974,9 +1089,8 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[
"sample"
][0]
output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
image = output.images[0]
expected_array = np.array(output_image)
sampled_array = np.array(image)
......@@ -1008,7 +1122,7 @@ class PipelineTesterMixin(unittest.TestCase):
strength=0.75,
guidance_scale=7.5,
generator=generator,
)["sample"][0]
).images[0]
expected_array = np.array(output_image)
sampled_array = np.array(image)
......
......@@ -14,6 +14,7 @@
# limitations under the License.
import tempfile
import unittest
from typing import Dict, List, Tuple
import numpy as np
import torch
......@@ -85,8 +86,8 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -114,9 +115,9 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0)
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
torch.manual_seed(0)
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -143,9 +144,9 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0)
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
torch.manual_seed(0)
new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -166,8 +167,8 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
......@@ -195,11 +196,64 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_dict = scheduler.step(residual, 0, sample, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs)
recursive_check(outputs_tuple, outputs_dict)
class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMScheduler,)
......@@ -270,7 +324,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]
pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
# if t > 0:
# noise = self.dummy_sample_deter
......@@ -356,7 +410,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for t in scheduler.timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta)["prev_sample"]
sample = scheduler.step(residual, t, sample, eta).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
......@@ -401,13 +455,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -438,13 +492,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
# copy over dummy past residual (must be after setting timesteps)
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -476,12 +530,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.ets = dummy_past_residuals[:]
scheduler_pt.ets = dummy_past_residuals_pt[:]
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
......@@ -535,14 +589,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
......@@ -573,7 +627,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
......@@ -587,11 +641,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
sample = scheduler.step_prk(residual, i, sample).prev_sample
for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
sample = scheduler.step_plms(residual, i, sample).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
......@@ -664,13 +718,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
......@@ -689,13 +743,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
......@@ -732,13 +786,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
for _ in range(scheduler.correct_steps):
with torch.no_grad():
model_output = model(sample, sigma_t)
sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"]
sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample
with torch.no_grad():
model_output = model(sample, sigma_t)
output = scheduler.step_pred(model_output, t, sample, **kwargs)
sample, _ = output["prev_sample"], output["prev_sample_mean"]
sample, _ = output.prev_sample, output.prev_sample_mean
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
......@@ -763,8 +817,8 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step_pred(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
......@@ -66,7 +66,7 @@ class TrainingTests(unittest.TestCase):
for i in range(4):
optimizer.zero_grad()
ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"]
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i]).sample
loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
loss.backward()
optimizer.step()
......@@ -78,7 +78,7 @@ class TrainingTests(unittest.TestCase):
for i in range(4):
optimizer.zero_grad()
ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"]
ddim_noise_pred = model(ddim_noisy_images, timesteps[i]).sample
loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
loss.backward()
optimizer.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment