Unverified commit f3937bc8 authored by Patrick von Platen, committed by GitHub

[Refactor] Remove set_seed (#289)



* [Refactor] Remove set_seed and class attributes

* apply anton's suggestions

* fix

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

* update

* make style

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <anton@huggingface.co>

* make fix-copies

* make style

* make style and new copies
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent 384fcac6
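
Taken together, the hunks below replace the seed-based API with an explicit torch.Generator that is passed through the pipelines into the schedulers. A minimal usage sketch of the new calling pattern, assuming the usual from_pretrained entry point, the 0.2/0.3-era dict-style return value, and the example checkpoint google/ddpm-cifar10-32 (any DDPM checkpoint works the same way):

    import torch

    from diffusers import DDPMPipeline

    # checkpoint name is only an example
    pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

    # instead of seeding the scheduler via set_seed, pass a seeded torch.Generator;
    # the pipeline forwards it to every stochastic scheduler step
    generator = torch.Generator().manual_seed(0)
    images = pipe(batch_size=1, generator=generator)["sample"]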
@@ -56,7 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
             model_output = self.unet(image, t)["sample"]

             # 2. compute previous image: x_t -> t_t-1
-            image = self.scheduler.step(model_output, t, image)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
...
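
The reason for threading generator into scheduler.step is that the DDPM update draws fresh Gaussian noise at every timestep; sourcing that noise from a caller-provided generator makes the whole denoising trajectory reproducible without touching PyTorch's global RNG. A small self-contained illustration of that property (plain PyTorch, not diffusers code):

    import torch

    g1 = torch.Generator().manual_seed(42)
    g2 = torch.Generator().manual_seed(42)

    # noise drawn from an explicit generator is reproducible...
    assert torch.equal(torch.randn(2, 3, generator=g1), torch.randn(2, 3, generator=g2))

    # ...and does not consume or depend on the global RNG state
    state_before = torch.get_rng_state()
    torch.randn(2, 3, generator=torch.Generator().manual_seed(0))
    assert torch.equal(state_before, torch.get_rng_state())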
@@ -30,7 +30,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         model = self.unet

-        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
         sample = sample.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)
@@ -42,11 +42,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             # correction step
             for _ in range(self.scheduler.correct_steps):
                 model_output = self.unet(sample, sigma_t)["sample"]
-                sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
+                sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]

             # prediction step
             model_output = model(sample, sigma_t)["sample"]
-            output = self.scheduler.step_pred(model_output, t, sample)
+            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

             sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
...
@@ -19,6 +19,7 @@ class KarrasVePipeline(DiffusionPipeline):
     differential equations." https://arxiv.org/abs/2011.13456
     """

+    # add type hints for linting
     unet: UNet2DModel
     scheduler: KarrasVeScheduler
...
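
The unet/scheduler annotations added above are plain class-level type hints: they do not change runtime behaviour, but they let linters and IDEs resolve the attribute types that register_modules assigns dynamically. A hypothetical pipeline sketch using the same pattern (import paths assumed to be the package's top-level exports):

    from diffusers import DiffusionPipeline, KarrasVeScheduler, UNet2DModel


    class MyKarrasVePipeline(DiffusionPipeline):  # hypothetical example class
        # class-level hints so static tools know the types of the registered modules
        unet: UNet2DModel
        scheduler: KarrasVeScheduler

        def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
            super().__init__()
            self.register_modules(unet=unet, scheduler=scheduler)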
@@ -14,8 +14,8 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

-# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
-from typing import Union
+import warnings
+from typing import Optional, Union

 import numpy as np
 import torch
...
@@ -98,6 +98,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

     def set_seed(self, seed):
+        warnings.warn(
+            "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
+            " generator instead.",
+            DeprecationWarning,
+        )
         tensor_format = getattr(self, "tensor_format", "pt")
         if tensor_format == "np":
             np.random.seed(seed)
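...

The warning block keeps set_seed working while steering users toward the generator argument. Because DeprecationWarning is hidden by Python's default filters, a script or test has to opt in to see it; a small sketch, assuming ScoreSdeVeScheduler is importable from the package root:

    import warnings

    from diffusers import ScoreSdeVeScheduler  # assumed top-level export

    scheduler = ScoreSdeVeScheduler()

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # surface DeprecationWarning, hidden by default
        scheduler.set_seed(0)            # still works, but is slated for removal in 0.4.0

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)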
@@ -111,14 +116,14 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-        seed=None,
+        generator: Optional[torch.Generator] = None,
+        **kwargs,
     ):
         """
         Predict the sample at the previous timestep by reversing the SDE.
         """
-        if seed is not None:
-            self.set_seed(seed)
-        # TODO(Patrick) non-PyTorch
+        if "seed" in kwargs and kwargs["seed"] is not None:
+            self.set_seed(kwargs["seed"])

         if self.timesteps is None:
             raise ValueError(
@@ -140,7 +145,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         drift = drift - diffusion[:, None, None, None] ** 2 * model_output

         # equation 6: sample noise for the diffusion term of
-        noise = self.randn_like(sample)
+        noise = self.randn_like(sample, generator=generator)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
         prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
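...

The generator/**kwargs combination in the new signature is a compatibility shim: new callers pass an explicit torch.Generator, while old call sites that still pass seed=... are routed through the deprecated set_seed. The same pattern in isolation, with a hypothetical step function rather than the library API:

    import warnings
    from typing import Optional

    import torch


    def step(sample: torch.Tensor, generator: Optional[torch.Generator] = None, **kwargs) -> torch.Tensor:
        """Hypothetical step illustrating the seed -> generator deprecation shim."""
        if "seed" in kwargs and kwargs["seed"] is not None:
            # legacy path: keep old `seed=` call sites working, but warn
            warnings.warn("`seed` is deprecated, pass a `generator` instead.", DeprecationWarning)
            torch.manual_seed(kwargs["seed"])
        # new path: draw noise from the explicit generator (or the global RNG if None)
        noise = torch.randn(sample.shape, generator=generator)
        return sample + noise


    # new API: explicit generator
    out = step(torch.zeros(2, 2), generator=torch.Generator().manual_seed(0))
    # legacy API: still works, routed through the deprecation branch
    out_legacy = step(torch.zeros(2, 2), seed=0)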
@@ -151,14 +156,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         self,
         model_output: Union[torch.FloatTensor, np.ndarray],
         sample: Union[torch.FloatTensor, np.ndarray],
-        seed=None,
+        generator: Optional[torch.Generator] = None,
+        **kwargs,
     ):
         """
         Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
         after making the prediction for the previous timestep.
         """
-        if seed is not None:
-            self.set_seed(seed)
+        if "seed" in kwargs and kwargs["seed"] is not None:
+            self.set_seed(kwargs["seed"])
         if self.timesteps is None:
             raise ValueError(
@@ -167,7 +173,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"

         # sample noise for correction
-        noise = self.randn_like(sample)
+        noise = self.randn_like(sample, generator=generator)

         # compute step size from the model_output, the noise, and the snr
         grad_norm = self.norm(model_output)
...
 # This file is autogenerated by the command `make fix-copies`, do not edit.
 # flake8: noqa
+
 from ..utils import DummyObject, requires_backends
...
 # This file is autogenerated by the command `make fix-copies`, do not edit.
 # flake8: noqa
+
 from ..utils import DummyObject, requires_backends
...
@@ -107,7 +107,7 @@ def create_dummy_files():
     for backend, objects in backend_specific_objects.items():
         backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
         dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
-        dummy_file += "# flake8: noqa\n"
+        dummy_file += "# flake8: noqa\n\n"
         dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
         dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
         dummy_files[backend] = dummy_file
...
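
With the extra "\n", regenerated dummy modules get a blank line between "# flake8: noqa" and the import, which is what the autogenerated files earlier in this diff now show. A standalone reconstruction of just the header string that this loop builds:

    # rebuild only the header portion to show the effect of the extra "\n"
    dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
    dummy_file += "# flake8: noqa\n\n"
    dummy_file += "from ..utils import DummyObject, requires_backends\n\n"

    print(dummy_file)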