Unverified commit 2cbdc586, authored by Will Berman, committed by GitHub

dynamic threshold sampling bug fixes and docs (#3003)

dynamic threshold sampling bug fix and docs
parent dcfa6e1d
@@ -201,15 +201,38 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
@@ -315,14 +338,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         )

         # 4. Clip or threshold "predicted x_0"
-        if self.config.clip_sample:
+        if self.config.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.config.clip_sample:
             pred_original_sample = pred_original_sample.clamp(
                 -self.config.clip_sample_range, self.config.clip_sample_range
             )

-        if self.config.thresholding:
-            pred_original_sample = self._threshold_sample(pred_original_sample)
-
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self._get_variance(timestep, prev_timestep)
...
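The step() change makes the two post-processing modes mutually exclusive, with thresholding taking precedence; previously the predicted x_0 was hard-clipped first and then thresholded, so the quantile was computed on already-clipped values. A usage sketch, relying only on the config flags this diff references (the numeric values are illustrative):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(
    thresholding=True,
    dynamic_thresholding_ratio=0.995,
    sample_max_value=1.5,
    clip_sample=True,  # ignored while thresholding=True after this fix (if/elif)
)
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 32, 32)
model_output = torch.randn(1, 3, 32, 32)  # stand-in for a UNet epsilon prediction
out = scheduler.step(model_output, scheduler.timesteps[0], sample)
print(out.prev_sample.shape)  # torch.Size([1, 3, 32, 32])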
@@ -241,15 +241,38 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return variance

     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def step(
         self,
@@ -309,14 +332,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         )

         # 3. Clip or threshold "predicted x_0"
-        if self.config.clip_sample:
+        if self.config.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.config.clip_sample:
             pred_original_sample = pred_original_sample.clamp(
                 -self.config.clip_sample_range, self.config.clip_sample_range
             )

-        if self.config.thresholding:
-            pred_original_sample = self._threshold_sample(pred_original_sample)
-
         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
         pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
...
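For context on the unchanged lines above: formula (7) of https://arxiv.org/pdf/2006.11239.pdf gives the posterior mean that step() assembles after the clip/threshold stage. Written out, the first coefficient is pred_original_sample_coeff; the second is current_sample_coeff further down in the full file:

\tilde{\mu}_t(x_t, x_0)
  = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0
  + \frac{\sqrt{\alpha_t}\,\bigl(1 - \bar{\alpha}_{t-1}\bigr)}{1 - \bar{\alpha}_t}\, x_t

with alpha_prod_t_prev = ᾱ_{t−1}, current_beta_t = β_t, and beta_prod_t = 1 − ᾱ_t in the code's naming.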
@@ -196,15 +196,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -236,11 +259,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             )

         if self.config.thresholding:
-            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-            orig_dtype = x0_pred.dtype
-            if orig_dtype not in [torch.float, torch.double]:
-                x0_pred = x0_pred.float()
-            x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+            x0_pred = self._threshold_sample(x0_pred)

         if self.config.algorithm_type == "deis":
             alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
...
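In DEISMultistepScheduler (and the three solver schedulers below), the caller-side dtype dance is dropped because the copied _threshold_sample now performs the upcast and downcast itself. The reason the upcast exists at all is that torch.quantile only accepts float32/float64 input; a quick sketch of the failure mode and the round-trip pattern:

import torch

x = torch.randn(2, 4, dtype=torch.float16)

try:
    torch.quantile(x.abs(), 0.995, dim=1)  # raises: quantile needs float32/float64 input
except RuntimeError as err:
    print(err)

# The pattern that now lives inside _threshold_sample: upcast, compute, cast back.
y = torch.quantile(x.abs().float(), 0.995, dim=1).to(x.dtype)
print(y.dtype)  # torch.float16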
@@ -207,15 +207,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -256,11 +279,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             )

             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
             return x0_pred
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
...
@@ -239,15 +239,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -288,11 +311,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             )

             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
             return x0_pred
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
...
@@ -212,15 +212,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -253,11 +276,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             )

             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
             return x0_pred
         else:
             if self.config.prediction_type == "epsilon":
...
@@ -201,7 +201,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
         result_mean = torch.mean(torch.abs(sample))

-        assert abs(result_mean.item() - 0.6405) < 1e-3
+        assert abs(result_mean.item() - 1.1364) < 1e-3

     def test_full_loop_with_v_prediction(self):
         sample = self.full_loop(prediction_type="v_prediction")
...
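The expected mean moves because the fixed _threshold_sample rescales by s per image instead of merely clamping against a per-batch dynamic_max_val. Note also that with sample_max_value=0.5 the bounds passed to torch.clamp invert (min=1 > max=0.5), and torch.clamp then resolves every element to the max, so samples are rescaled by 1/0.5 at each thresholded step, which is part of why the expected value now exceeds 1. A rough sketch of what a full_loop-style check does; the dummy model and shapes below are hypothetical stand-ins for the test's fixed UNet and fixtures:

import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5
)
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 3, 8, 8)  # dummy initial noise
for t in scheduler.timesteps:
    model_output = 0.1 * sample  # dummy model in place of the test's small UNet
    sample = scheduler.step(model_output, t, sample).prev_sample

result_mean = torch.mean(torch.abs(sample))
print(result_mean.item())  # the real test asserts ~1.1364 with its fixed model and seed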