Unverified Commit cc59b056 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples (#334)



* add outputs for models

* add for pipelines

* finish schedulers

* better naming

* adapt tests as well

* replace dict access with . access

* make schedulers works

* finish

* correct readme

* make backwards compatible

* up

* small fix

* finish

* more fixes

* more fixes

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/models/vae.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* Adapt model outputs

* Apply more suggestions

* finish examples

* correct
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent daddd98b
......@@ -10,6 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -57,6 +58,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
if isinstance(prompt, str):
......@@ -83,7 +85,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_image = preprocess(init_image)
# encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image.to(self.device)).sample(generator=generator)
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# expand init_latents for batch_size
......@@ -158,7 +161,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
t = t.to(self.unet.dtype)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -167,13 +170,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents.to(self.vae.dtype))
image = self.vae.decode(latents.to(self.vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
......@@ -185,4 +188,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
......@@ -11,6 +11,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -72,6 +73,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
if isinstance(prompt, str):
......@@ -98,7 +100,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
init_image = preprocess_image(init_image).to(self.device)
# encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image).sample(generator=generator)
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size
......@@ -166,7 +170,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -174,7 +178,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
......@@ -182,7 +186,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
......@@ -194,4 +198,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
#!/usr/bin/env python3
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import KarrasVeScheduler
......@@ -35,8 +35,9 @@ class KarrasVePipeline(DiffusionPipeline):
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[Tuple, ImagePipelineOutput]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
......@@ -71,7 +72,7 @@ class KarrasVePipeline(DiffusionPipeline):
# 3. Predict the noise residual given the noise magnitude `sigma_hat`
# The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
# 4. Evaluate dx/dt at sigma_hat
# 5. Take Euler step from sigma to sigma_prev
......@@ -80,20 +81,23 @@ class KarrasVePipeline(DiffusionPipeline):
if sigma_prev != 0:
# 6. Apply 2nd order correction
# The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
step_output = self.scheduler.step_correct(
model_output,
sigma_hat,
sigma_prev,
sample_hat,
step_output["prev_sample"],
step_output.prev_sample,
step_output["derivative"],
)
sample = step_output["prev_sample"]
sample = step_output.prev_sample
sample = (sample / 2 + 0.5).clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
image = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
image = self.numpy_to_pil(sample)
return {"sample": sample}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
......@@ -16,13 +16,13 @@
# and https://github.com/hojonathanho/diffusion
import math
from typing import Optional, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -116,7 +116,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
......@@ -174,7 +176,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
......
......@@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Optional, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -135,7 +135,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True,
generator=None,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
......@@ -177,7 +179,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
return {"prev_sample": pred_prev_sample}
if not return_dict:
return (pred_prev_sample,)
return SchedulerOutput(prev_sample=pred_prev_sample)
def add_noise(
self,
......
......@@ -13,15 +13,34 @@
# limitations under the License.
from typing import Union
from dataclasses import dataclass
from typing import Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class KarrasVeOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Derivative of predicted original image sample (x_0).
    """

    # Sample for the next denoising step.
    prev_sample: torch.FloatTensor
    # dx/dt estimate used by the second-order correction step.
    derivative: torch.FloatTensor
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
......@@ -102,12 +121,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
return {"prev_sample": sample_prev, "derivative": derivative}
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
def step_correct(
self,
......@@ -117,11 +141,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_prev: Union[torch.FloatTensor, np.ndarray],
derivative: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
return {"prev_sample": sample_prev, "derivative": derivative_corr}
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Tuple, Union
import numpy as np
import torch
......@@ -20,7 +20,7 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
......@@ -100,7 +100,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
sigma = self.sigmas[timestep]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
......@@ -121,7 +122,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
)
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(self, original_samples, noise, timesteps):
sigmas = self.match_shape(self.sigmas[timesteps], noise)
......
......@@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Union
from typing import Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -133,18 +133,21 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
else:
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
def step_prk(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
......@@ -176,14 +179,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
self.counter += 1
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def step_plms(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
......@@ -226,7 +233,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
......
......@@ -15,13 +15,32 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
import warnings
from typing import Optional, Union
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the ScoreSdeVeScheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
    """

    # Noisy sample for the next step (mean plus diffusion noise).
    prev_sample: torch.FloatTensor
    # The noise-free mean of `prev_sample`.
    prev_sample_mean: torch.FloatTensor
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
......@@ -117,8 +136,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs,
):
) -> Union[SdeVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE.
"""
......@@ -150,15 +170,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean}
if not return_dict:
return (prev_sample, prev_sample_mean)
return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs,
):
) -> Union[SchedulerOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
......@@ -186,7 +210,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
......@@ -11,15 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Union
import numpy as np
import torch
from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@dataclass
class SchedulerOutput(BaseOutput):
    """
    Base class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    # Sample to feed back into the model on the next denoising iteration.
    prev_sample: torch.FloatTensor
class SchedulerMixin:
config_name = SCHEDULER_CONFIG_NAME
......
......@@ -33,6 +33,7 @@ from .import_utils import (
requires_backends,
)
from .logging import get_logger
from .outputs import BaseOutput
logger = get_logger(__name__)
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic utilities
"""
import warnings
from collections import OrderedDict
from dataclasses import fields
from typing import Any, Tuple
import numpy as np
from .import_utils import is_torch_available
def is_tensor(x):
    """Return `True` when `x` is a `torch.Tensor` or a `np.ndarray`."""
    # numpy is always available; torch is optional, so its tensor type is only
    # added to the candidate list when the torch backend can be imported.
    tensor_types = [np.ndarray]

    if is_torch_available():
        import torch

        tensor_types.append(torch.Tensor)

    return isinstance(x, tuple(tensor_types))
class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    <Tip warning={true}>

    You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
    before.

    </Tip>
    """

    def __post_init__(self):
        # Runs after every dataclass-subclass __init__: mirrors the dataclass
        # fields into the underlying OrderedDict so instances support both
        # attribute access (`out.sample`) and dict access (`out["sample"]`).
        class_fields = fields(self)

        # Safety and consistency checks
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        # Special case: a single non-tensor first field may itself be a
        # (key, value) mapping/iterator whose entries become the output fields.
        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for element in iterator:
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        # Not a (str, value) pair — stop treating it as a mapping.
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            # Common case: register every non-None field in declaration order.
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        # Outputs are immutable mappings: the dict-mutation API is disabled.
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        # String keys behave like dict lookup; integer/slice keys index the
        # tuple of non-None values, so the output also acts like a tuple.
        if isinstance(k, str):
            inner_dict = {k: v for (k, v) in self.items()}
            if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
                # Backward compatibility: the pipelines used to return
                # {"sample": images}; redirect the old key to `images`.
                warnings.warn(
                    "The keyword 'sample' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
                    " `'images'` instead.",
                    DeprecationWarning,
                )
                return inner_dict["images"]
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        # Keep attribute and dict views in sync when an existing key is rebound.
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())
......@@ -15,6 +15,7 @@
import inspect
import tempfile
from typing import Dict, List, Tuple
import numpy as np
import torch
......@@ -39,12 +40,12 @@ class ModelTesterMixin:
with torch.no_grad():
image = model(**inputs_dict)
if isinstance(image, dict):
image = image["sample"]
image = image.sample
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image["sample"]
new_image = new_image.sample
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
......@@ -57,11 +58,11 @@ class ModelTesterMixin:
with torch.no_grad():
first = model(**inputs_dict)
if isinstance(first, dict):
first = first["sample"]
first = first.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second["sample"]
second = second.sample
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
......@@ -80,7 +81,7 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
......@@ -122,12 +123,12 @@ class ModelTesterMixin:
output_1 = model(**inputs_dict)
if isinstance(output_1, dict):
output_1 = output_1["sample"]
output_1 = output_1.sample
output_2 = new_model(**inputs_dict)
if isinstance(output_2, dict):
output_2 = output_2["sample"]
output_2 = output_2.sample
self.assertEqual(output_1.shape, output_2.shape)
......@@ -140,7 +141,7 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
......@@ -157,9 +158,47 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model)
def test_scheduler_outputs_equivalence(self):
    """Check that a model produces identical tensors whether it returns an output dataclass or a plain tuple."""

    def set_nan_tensor_to_zero(t):
        # NaN != NaN, so `t != t` selects exactly the NaN entries.
        t[t != t] = 0
        return t

    def recursive_check(tuple_object, dict_object):
        # Walk both structures in parallel and compare every leaf tensor.
        if isinstance(tuple_object, (List, Tuple)):
            for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                recursive_check(tuple_iterable_value, dict_iterable_value)
        elif isinstance(tuple_object, Dict):
            for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                recursive_check(tuple_iterable_value, dict_iterable_value)
        elif tuple_object is None:
            return
        else:
            self.assertTrue(
                torch.allclose(
                    set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                ),
                msg=(
                    "Tuple and dict output are not equal. Difference:"
                    # .any() collapses the element-wise masks to a single flag so the
                    # failure message stays readable instead of dumping whole tensors.
                    f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                    f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                    f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                ),
            )

    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    # no_grad for consistency with the other forward-pass tests in this mixin.
    with torch.no_grad():
        outputs_dict = model(**inputs_dict)
        outputs_tuple = model(**inputs_dict, return_dict=False)

    recursive_check(outputs_tuple, outputs_dict)
......@@ -77,7 +77,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# time_step = torch.tensor([10])
#
# with torch.no_grad():
# output = model(noise, time_step)["sample"]
# output = model(noise, time_step).sample
#
# output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
......@@ -129,7 +129,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)["sample"]
image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None"
......@@ -147,7 +147,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
......@@ -258,7 +258,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
......@@ -283,7 +283,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
......
......@@ -87,7 +87,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image, sample_posterior=True)
output = model(image, sample_posterior=True).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
......
......@@ -85,7 +85,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image)
output = model(image).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
......
This diff is collapsed.
......@@ -14,6 +14,7 @@
# limitations under the License.
import tempfile
import unittest
from typing import Dict, List, Tuple
import numpy as np
import torch
......@@ -85,8 +86,8 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -114,9 +115,9 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0)
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
torch.manual_seed(0)
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -143,9 +144,9 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0)
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
torch.manual_seed(0)
new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -166,8 +167,8 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
......@@ -195,11 +196,64 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_dict = scheduler.step(residual, 0, sample, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs)
recursive_check(outputs_tuple, outputs_dict)
class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMScheduler,)
......@@ -270,7 +324,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]
pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
# if t > 0:
# noise = self.dummy_sample_deter
......@@ -356,7 +410,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for t in scheduler.timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta)["prev_sample"]
sample = scheduler.step(residual, t, sample, eta).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
......@@ -401,13 +455,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -438,13 +492,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
# copy over dummy past residual (must be after setting timesteps)
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
......@@ -476,12 +530,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.ets = dummy_past_residuals[:]
scheduler_pt.ets = dummy_past_residuals_pt[:]
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
......@@ -535,14 +589,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
......@@ -573,7 +627,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
......@@ -587,11 +641,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
sample = scheduler.step_prk(residual, i, sample).prev_sample
for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
sample = scheduler.step_plms(residual, i, sample).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
......@@ -664,13 +718,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
......@@ -689,13 +743,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
......@@ -732,13 +786,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
for _ in range(scheduler.correct_steps):
with torch.no_grad():
model_output = model(sample, sigma_t)
sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"]
sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample
with torch.no_grad():
model_output = model(sample, sigma_t)
output = scheduler.step_pred(model_output, t, sample, **kwargs)
sample, _ = output["prev_sample"], output["prev_sample_mean"]
sample, _ = output.prev_sample, output.prev_sample_mean
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
......@@ -763,8 +817,8 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step_pred(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
......@@ -66,7 +66,7 @@ class TrainingTests(unittest.TestCase):
for i in range(4):
optimizer.zero_grad()
ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"]
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i]).sample
loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
loss.backward()
optimizer.step()
......@@ -78,7 +78,7 @@ class TrainingTests(unittest.TestCase):
for i in range(4):
optimizer.zero_grad()
ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"]
ddim_noise_pred = model(ddim_noisy_images, timesteps[i]).sample
loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
loss.backward()
optimizer.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment