Unverified Commit 6b09f370 authored by Anton Lozhkov, committed by GitHub

[Scheduler design] The pragmatic approach (#719)

* init

* improve add_noise

* [debug start] run slow test

* [debug end]

* quick revert

* Add docstrings and warnings + API tests

* Make the warning less spammy
parent 726aba08
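The net effect of the diff below is a uniform scheduler contract: every scheduler exposes an `init_noise_sigma` attribute, a `scale_model_input(sample, timestep)` method, and a `step(model_output, timestep, sample, ...)` method that takes timestep values rather than loop indices. A minimal sketch of that contract (the `ToyScheduler` class and its update rule are illustrative, not part of this commit):

```python
import torch
from dataclasses import dataclass


@dataclass
class SchedulerOutput:
    prev_sample: torch.FloatTensor


class ToyScheduler:
    """Hypothetical scheduler showing the three-member public API this commit standardizes."""

    def __init__(self, num_train_timesteps: int = 1000):
        # 1.0 for variance-preserving schedulers; sigma_max for variance-exploding ones
        self.init_noise_sigma = 1.0
        self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)

    def scale_model_input(self, sample: torch.FloatTensor, timestep=None) -> torch.FloatTensor:
        # no-op here; K-LMS instead divides by (sigma**2 + 1) ** 0.5
        return sample

    def step(self, model_output, timestep, sample) -> SchedulerOutput:
        # placeholder update; a real scheduler computes x_{t-1} from x_t here
        return SchedulerOutput(prev_sample=sample - model_output / len(self.timesteps))
```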
@@ -57,7 +57,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         model = self.unet
 
-        sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
         sample = sample.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
...
@@ -281,9 +281,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
         # It's more optimized to move all timesteps to correct device beforehand
         timesteps_tensor = self.scheduler.timesteps.to(self.device)
 
-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -297,10 +296,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -311,10 +307,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
...
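With both changes in place, the Stable Diffusion denoising loop no longer special-cases any scheduler. A condensed sketch of the resulting loop (assumes `scheduler`, `unet`, `latents`, `text_embeddings`, `guidance_scale`, `num_inference_steps`, and `extra_step_kwargs` are prepared as in the pipeline above):

```python
import torch

scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma  # scale initial noise once, for any scheduler

for t in scheduler.timesteps:
    latent_model_input = torch.cat([latents] * 2)  # classifier-free guidance batch
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # every scheduler now receives the timestep value `t`, never the loop index
    latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
```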
@@ -226,13 +226,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -310,16 +306,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in enumerate(self.progress_bar(timesteps)):
-            t_index = t_start + i
-
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-
-            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[t_index]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -330,10 +319,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
...
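For a concrete feel of the timestep selection above: with `num_inference_steps=50`, `strength=0.6`, and `steps_offset=1`, the pipeline noises the image to the 31st timestep from the end of the schedule. A quick arithmetic check (numbers are illustrative):

```python
num_inference_steps = 50
strength = 0.6
offset = 1  # scheduler.config steps_offset

init_timestep = int(num_inference_steps * strength) + offset  # int(30) + 1 = 31
init_timestep = min(init_timestep, num_inference_steps)       # still 31

# scheduler.timesteps[-31] is the noising target; denoising later resumes
# from t_start = 50 - 31 = 19 steps into the schedule
```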
@@ -260,13 +260,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -348,13 +344,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in tqdm(enumerate(timesteps)):
-            t_index = t_start + i
-
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[t_index]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -365,14 +357,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            # masking
+            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
 
             latents = (init_latents_proper * mask) + (latents * (1 - mask))
...
@@ -147,9 +147,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        latents = latents * self.scheduler.init_noise_sigma
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -163,10 +161,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
             noise_pred = self.unet(
@@ -180,11 +175,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
             latents = np.array(latents)
 
             # call the callback, if provided
...
@@ -69,7 +69,7 @@ class KarrasVePipeline(DiffusionPipeline):
         model = self.unet
 
         # sample x_0 ~ N(0, sigma_0^2 * I)
-        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
         sample = sample.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
...
@@ -152,10 +152,27 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # whether we use the final alpha of the "non-previous" one.
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
         # setable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
...
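Because `scale_model_input` is a no-op for DDIM (and for DDPM and PNDM below), pipelines can call it unconditionally, which is what makes schedulers drop-in replacements for one another. A hedged usage sketch; the model id is the usual public checkpoint, and the exact `from_config` call shape is assumed from the diffusers API of this period:

```python
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# swapping in K-LMS requires no pipeline changes: the pipeline already calls
# scale_model_input and passes timestep values (not indices) to step()
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
image = pipe("a photo of an astronaut riding a horse").images[0]
```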
@@ -140,12 +140,29 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
         self.one = torch.tensor(1.0)
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
         # setable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
         self.variance_type = variance_type
 
+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
...
@@ -95,11 +95,28 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
             take_from=kwargs,
         )
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = sigma_max
+
         # setable values
         self.num_inference_steps: int = None
        self.timesteps: np.IntTensor = None
         self.schedule: torch.FloatTensor = None  # sigma(t_i)
 
+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
...
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -102,11 +102,36 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = self.sigmas.max()
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.derivatives = []
+        self.is_scale_input_called = False
+
+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+    ) -> torch.FloatTensor:
+        """
+        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        self.is_scale_input_called = True
+        return sample
 
     def get_lms_coefficient(self, order, t, current_order):
         """
@@ -154,7 +179,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.FloatTensor,
-        timestep: int,
+        timestep: Union[float, torch.FloatTensor],
         sample: torch.FloatTensor,
         order: int = 4,
         return_dict: bool = True,
@@ -165,7 +190,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         Args:
             model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-            timestep (`int`): current discrete timestep in the diffusion chain.
+            timestep (`float`): current timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             order: coefficient for multi-step inference.
@@ -177,7 +202,21 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             When returning a tuple, the first element is the sample tensor.
 
         """
-        sigma = self.sigmas[timestep]
+        if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
+            warnings.warn(
+                f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
+                "Make sure to pass one of the `scheduler.timesteps`"
+            )
+        if not self.is_scale_input_called:
+            warnings.warn(
+                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                "See `StableDiffusionPipeline` for a usage example."
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         pred_original_sample = sample - sigma * model_output
@@ -189,8 +228,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             self.derivatives.pop(0)
 
         # 3. Compute linear multistep coefficients
-        order = min(timestep + 1, order)
-        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+        order = min(step_index + 1, order)
+        lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]
 
         # 4. Compute previous sample based on the derivatives path
         prev_sample = sample + sum(
@@ -206,12 +245,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
+        timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         sigmas = self.sigmas.to(original_samples.device)
+        schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
 
-        sigma = sigmas[timesteps].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
...
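`add_noise` now resolves sigmas the same way: callers pass timestep values taken from `scheduler.timesteps` (floats), not integer positions, and each value is matched back to its schedule index. A hedged usage sketch (default scheduler config; shapes are illustrative):

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()  # default config, for illustration
scheduler.set_timesteps(50)

clean_latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(clean_latents)

# pass one timestep *value* per sample, drawn from the schedule itself
t = scheduler.timesteps[10]
noisy = scheduler.add_noise(clean_latents, noise, torch.stack([t, t]))
```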
@@ -129,6 +129,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -342,6 +345,19 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)
 
+    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
...
@@ -84,11 +84,28 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             take_from=kwargs,
         )
 
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = sigma_max
+
         # setable values
         self.timesteps = None
 
         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
 
+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(
         self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
     ):
...
@@ -201,7 +201,7 @@ class SchedulerCommonTest(unittest.TestCase):
         )
 
         kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        num_inference_steps = kwargs.pop("num_inference_steps", 50)
 
         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config()
@@ -226,6 +226,27 @@ class SchedulerCommonTest(unittest.TestCase):
             recursive_check(outputs_tuple, outputs_dict)
 
+    def test_scheduler_public_api(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            self.assertTrue(
+                hasattr(scheduler, "init_noise_sigma"),
+                f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
+            )
+            self.assertTrue(
+                hasattr(scheduler, "scale_model_input"),
+                f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
+            )
+            self.assertTrue(
+                hasattr(scheduler, "step"),
+                f"{scheduler_class} does not implement a required class method `step(...)`",
+            )
+
+            sample = self.dummy_sample
+            scaled_sample = scheduler.scale_model_input(sample, 0.0)
+            self.assertEqual(sample.shape, scaled_sample.shape)
+
 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)
@@ -865,14 +886,14 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
         scheduler.set_timesteps(self.num_inference_steps)
 
         model = self.dummy_model()
-        sample = self.dummy_sample_deter * scheduler.sigmas[0]
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
 
         for i, t in enumerate(scheduler.timesteps):
-            sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
+            sample = scheduler.scale_model_input(sample, t)
 
             model_output = model(sample, t)
 
-            output = scheduler.step(model_output, i, sample)
+            output = scheduler.step(model_output, t, sample)
             sample = output.prev_sample
 
         result_sum = torch.sum(torch.abs(sample))
...