Unverified Commit 55660cfb authored by clarencechen's avatar clarencechen Committed by GitHub
Browse files

Improve dynamic thresholding and extend to DDPM and DDIM Schedulers (#2528)



* Improve dynamic threshold

* Update code

* Add dynamic threshold to ddim and ddpm

* Encapsulate and leverage code copy mechanism

Update style

* Clean up DDPM/DDIM constructor arguments

* add test

* also add to unipc

---------
Co-authored-by: default avatarPeter Lin <peterlin9863@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 46bef6e3
...@@ -70,8 +70,9 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -70,8 +70,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately class_embed_type (`str`, *optional*, defaults to None):
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None): num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.
......
...@@ -90,8 +90,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -90,8 +90,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately class_embed_type (`str`, *optional*, defaults to None):
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
num_class_embeds (`int`, *optional*, defaults to None): num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.
......
...@@ -171,8 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -171,8 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately class_embed_type (`str`, *optional*, defaults to None):
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
num_class_embeds (`int`, *optional*, defaults to None): num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.
......
...@@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
trained_betas (`np.ndarray`, optional): trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`): clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability. option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, default `True`): set_alpha_to_one (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
...@@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf) https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -128,6 +139,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -128,6 +139,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -184,6 +199,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -184,6 +199,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
...@@ -286,9 +313,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -286,9 +313,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
" `v_prediction`" " `v_prediction`"
) )
# 4. Clip "predicted x_0" # 4. Clip or threshold "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1) pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
# 5. compute variance: "sigma_t(η)" -> see formula (16) # 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
......
...@@ -98,11 +98,22 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -98,11 +98,22 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`): clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability. option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
prediction_type (`str`, default `epsilon`, optional): prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf) https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -119,7 +130,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -119,7 +130,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_type: str = "fixed_small", variance_type: str = "fixed_small",
clip_sample: bool = True, clip_sample: bool = True,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
clip_sample_range: Optional[float] = 1.0, thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -226,6 +240,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -226,6 +240,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
...@@ -283,12 +308,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -283,12 +308,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
" `v_prediction` for the DDPMScheduler." " `v_prediction` for the DDPMScheduler."
) )
# 3. Clip "predicted x_0" # 3. Clip or threshold "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
pred_original_sample = torch.clamp( pred_original_sample = pred_original_sample.clamp(
pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range -self.config.clip_sample_range, self.config.clip_sample_range
) )
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
......
...@@ -96,7 +96,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -96,7 +96,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487). (https://arxiv.org/abs/2205.11487).
sample_max_value (`float`, default `1.0`): sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid woks when `thresholding=True` the threshold value for dynamic thresholding. Valid only when `thresholding=True`
algorithm_type (`str`, default `deis`): algorithm_type (`str`, default `deis`):
the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in
the future the future
...@@ -194,6 +194,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -194,6 +194,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order ] * self.config.solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
...@@ -228,15 +240,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -228,15 +240,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]: if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float() x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile( x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
if self.config.algorithm_type == "deis": if self.config.algorithm_type == "deis":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
......
...@@ -204,6 +204,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -204,6 +204,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order ] * self.config.solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
...@@ -247,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -247,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]: if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float() x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile( x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
return x0_pred return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model. # DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":
......
...@@ -237,6 +237,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -237,6 +237,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sample = None self.sample = None
self.orders = self.get_order_list(num_inference_steps) self.orders = self.get_order_list(num_inference_steps)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
...@@ -277,18 +289,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -277,18 +289,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
if self.config.thresholding: if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
dtype = x0_pred.dtype orig_dtype = x0_pred.dtype
dynamic_max_val = torch.quantile( if orig_dtype not in [torch.float, torch.double]:
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)).float(), x0_pred = x0_pred.float()
self.config.dynamic_thresholding_ratio, x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
dim=1,
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.to(dtype)
return x0_pred return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model. # DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":
......
...@@ -109,7 +109,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -109,7 +109,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
num_inference_steps (`int`): num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
""" """
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
...@@ -129,8 +130,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -129,8 +130,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional): sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation). initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). sigma_max (`float`, optional):
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
""" """
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
......
...@@ -116,7 +116,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -116,7 +116,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
num_inference_steps (`int`): num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
""" """
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
...@@ -143,8 +144,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -143,8 +144,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional): sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation). initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). sigma_max (`float`, optional):
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
""" """
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
......
...@@ -210,6 +210,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -210,6 +210,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.solver_p: if self.solver_p:
self.solver_p.set_timesteps(num_inference_steps, device=device) self.solver_p.set_timesteps(num_inference_steps, device=device)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
...@@ -245,15 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -245,15 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]: if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float() x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile( x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
return x0_pred return x0_pred
else: else:
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
......
...@@ -647,6 +647,16 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -647,6 +647,16 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]: for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
def test_thresholding(self):
self.check_over_configs(thresholding=False)
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
)
def test_prediction_type(self): def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]: for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type) self.check_over_configs(prediction_type=prediction_type)
...@@ -791,6 +801,16 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -791,6 +801,16 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]: for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
def test_thresholding(self):
self.check_over_configs(thresholding=False)
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
)
def test_time_indices(self): def test_time_indices(self):
for t in [1, 10, 49]: for t in [1, 10, 49]:
self.check_over_forward(time_step=t) self.check_over_forward(time_step=t)
...@@ -1212,6 +1232,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -1212,6 +1232,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.3301) < 1e-3 assert abs(result_mean.item() - 0.3301) < 1e-3
def test_full_loop_no_noise_thres(self):
sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.6405) < 1e-3
def test_full_loop_with_v_prediction(self): def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction") sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment