Unverified commit 726aba08, authored by Kashif Rasul, committed by GitHub

[Pytorch] pytorch only timesteps (#724)



* pytorch timesteps

* style

* get rid of if-else

* fix test
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 60c9634a
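After this commit, every scheduler stores `timesteps` as a `torch.Tensor`, and `set_timesteps` gains an optional `device` argument so the schedule can be created on the target device up front. A minimal sketch of the new calling convention, assuming a diffusers checkout at this commit (the scheduler choice and step count are illustrative):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)

# set_timesteps now accepts a device, so the schedule lands there directly
device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler.set_timesteps(50, device=device)

assert torch.is_tensor(scheduler.timesteps)  # always a torch.Tensor now
```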
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
 To this end, the design of schedulers is such that:
 
 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists).
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
 
 ## API
...
@@ -278,11 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
 
         # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
...
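Since `timesteps` is now always a tensor, the `torch.is_tensor` branch above becomes dead code. A self-contained sketch of the before/after (the 1000-step array is illustrative):

```python
import numpy as np
import torch

device = "cpu"  # stand-in for self.device

# Before: PNDM and friends exposed numpy timesteps, so the pipeline branched.
old_timesteps = np.arange(0, 1000)[::-1]
if torch.is_tensor(old_timesteps):
    timesteps_tensor = old_timesteps.to(device)
else:
    # .copy() materializes the reversed view before conversion
    timesteps_tensor = torch.tensor(old_timesteps.copy(), device=device)

# After: timesteps is always a torch.Tensor, so a single .to() suffices.
new_timesteps = torch.from_numpy(np.arange(0, 1000)[::-1].copy())
timesteps_tensor = new_timesteps.to(device)
```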
@@ -304,7 +304,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in enumerate(self.progress_bar(timesteps)):
             t_index = t_start + i
...
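`t_start` trims the schedule so img2img only runs the noisier tail; the inpaint hunk below applies the identical pattern. A numeric sketch with illustrative values:

```python
import numpy as np
import torch

num_inference_steps = 50
init_timestep = 40  # e.g. strength 0.8 keeps the last 40 of 50 steps
offset = 1          # scheduler's steps_offset

t_start = max(num_inference_steps - init_timestep + offset, 0)  # -> 11

# the slice is a tensor now, so it moves to the device in one call
all_timesteps = torch.from_numpy(np.arange(0, 1000, 20)[::-1].copy())
timesteps = all_timesteps[t_start:].to("cpu")
```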
@@ -342,7 +342,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in tqdm(enumerate(timesteps)):
             t_index = t_start + i
...
@@ -2,7 +2,7 @@
 
 - Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps.
 - Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are available in numpy, but can easily be transformed into PyTorch.
+- Schedulers are available in PyTorch and Jax.
 
 ## API
...
@@ -154,7 +154,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -166,7 +166,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         return variance
 
-    def set_timesteps(self, num_inference_steps: int, **kwargs):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -183,7 +183,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         self.timesteps += offset
 
     def step(
...
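Plugging in the defaults makes the arithmetic concrete: with 1000 training steps, 5 inference steps, and `offset=1` (the `steps_offset=1` from the test near the end of this diff), the code above produces exactly the `[801, 601, 401, 201, 1]` that the updated DDIM test asserts:

```python
import numpy as np
import torch

num_train_timesteps, num_inference_steps, offset = 1000, 5, 1
step_ratio = num_train_timesteps // num_inference_steps  # 200

timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps = torch.from_numpy(timesteps)  # int64 array -> LongTensor
timesteps += offset

print(timesteps)  # tensor([801, 601, 401, 201,   1])
```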
@@ -142,11 +142,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
         self.variance_type = variance_type
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -156,9 +156,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         """
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(
+        timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
-        )[::-1]
+        )[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
...
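The DDPM schedule is a plain stride over the training range; for example, 1000 training steps at 10 inference steps gives a stride of 100, and the reversed, copied array converts cleanly with `torch.from_numpy`. A sketch:

```python
import numpy as np
import torch

num_train_timesteps, num_inference_steps = 1000, 10
stride = num_train_timesteps // num_inference_steps  # 100

timesteps = np.arange(0, num_train_timesteps, stride)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to("cpu")
print(timesteps)  # tensor([900, 800, 700, 600, 500, 400, 300, 200, 100,   0])
```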
@@ -97,10 +97,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 
         # setable values
         self.num_inference_steps: int = None
-        self.timesteps: np.ndarray = None
+        self.timesteps: np.IntTensor = None
         self.schedule: torch.FloatTensor = None  # sigma(t_i)
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -110,7 +110,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 
         """
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         schedule = [
             (
                 self.config.sigma_max**2
@@ -118,7 +119,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
             )
             for i in self.timesteps
         ]
-        self.schedule = torch.tensor(schedule, dtype=torch.float32)
+        self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
 
     def add_noise_to_input(
         self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
...
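For KarrasVe, the float `schedule` of sigmas is built alongside the integer timesteps; passing `device=` to `torch.tensor` creates it on the target device instead of building on CPU and moving afterwards. A sketch where a simple geometric interpolation between `sigma_max` and `sigma_min` stands in for the partially elided formula above:

```python
import torch

device = "cpu"
num_inference_steps = 4
sigma_min, sigma_max = 0.02, 100.0  # illustrative config values

# Stand-in sigma schedule; the exact KarrasVe expression is elided above.
schedule = [
    sigma_max * (sigma_min / sigma_max) ** (i / (num_inference_steps - 1))
    for i in range(num_inference_steps)
]
schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
```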
@@ -147,7 +147,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.plms_timesteps = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -184,7 +184,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             ::-1
         ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy
 
-        self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)
 
         self.ets = []
         self.counter = 0
...
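PNDM stitches together Runge-Kutta warm-up steps (`prk_timesteps`) and a linear-multistep tail (`plms_timesteps`), which is where the duplicated values in the PNDM test below come from. A sketch of just the final concatenation and conversion, using stand-in arrays shaped like the test's expected output:

```python
import numpy as np
import torch

# Stand-in arrays; the real prk/plms construction is elided above.
prk_timesteps = np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601])
plms_timesteps = np.array([601, 501, 401, 301, 201, 101, 1])

timesteps = np.concatenate([prk_timesteps, plms_timesteps]).astype(np.int64)
timesteps = torch.from_numpy(timesteps).to("cpu")  # int64 -> LongTensor on device
```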
@@ -89,7 +89,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
 
         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
 
-    def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+    def set_timesteps(
+        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
+    ):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -101,7 +103,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
 
         """
         sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
-        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
+        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
 
     def set_sigmas(
         self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
...
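The continuous schedulers already produced tensors via `torch.linspace`; the change only threads `device=` through so no separate `.to()` is needed. For example:

```python
import torch

num_inference_steps, sampling_eps = 5, 1e-3
timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device="cpu")
# five values evenly spaced from 1.0 down to sampling_eps
```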
@@ -14,9 +14,8 @@
 
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
-# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
-
 import math
+from typing import Union
 
 import torch
@@ -52,8 +51,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
         self.discrete_sigmas = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+    def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
 
     def step_pred(self, score, x, t, generator=None):
         if self.timesteps is None:
...
@@ -354,7 +354,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(5)
-        assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()
+        assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1]))
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -568,10 +568,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(10)
-        assert np.equal(
+        assert torch.equal(
             scheduler.timesteps,
-            np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
-        ).all()
+            torch.LongTensor(
+                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
+            ),
+        )
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
...
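A note on the test rewrite: `torch.equal` only returns `True` when shape, dtype, and values all match, hence the explicit `torch.LongTensor` construction to mirror the int64 timesteps. A minimal illustration:

```python
import torch

a = torch.LongTensor([801, 601, 401, 201, 1])
assert torch.equal(a, torch.LongTensor([801, 601, 401, 201, 1]))
# dtype matters: an otherwise-equal float tensor does not compare equal
assert not torch.equal(a, torch.tensor([801.0, 601.0, 401.0, 201.0, 1.0]))
```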