OpenDAS / diffusers · Commits

Commit 2cbdc586 (unverified), authored Apr 09, 2023 by Will Berman, committed by GitHub on Apr 09, 2023
Parent: dcfa6e1d

    dynamic threshold sampling bug fixes and docs (#3003)

    dynamic threshold sampling bug fix and docs

Showing 7 changed files with 206 additions and 83 deletions (+206 -83)
src/diffusers/schedulers/scheduling_ddim.py                  +35 -13
src/diffusers/schedulers/scheduling_ddpm.py                  +35 -13
src/diffusers/schedulers/scheduling_deis_multistep.py        +33 -14
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py   +34 -14
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py  +34 -14
src/diffusers/schedulers/scheduling_unipc_multistep.py       +34 -14
tests/schedulers/test_scheduler_dpm_multi.py                 +1  -1
src/diffusers/schedulers/scheduling_ddim.py

@@ -201,15 +201,38 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
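As a reading aid for the docstring quoted in the hunk above, here is a minimal standalone sketch of the same per-image thresholding operation on a dummy batch. It is illustrative only: the `dynamic_threshold` helper name, the tensor shapes, and the ratio/max values are assumptions for the demo, not part of the commit.

import torch

def dynamic_threshold(x0_pred: torch.Tensor, ratio: float = 0.995, max_value: float = 2.0) -> torch.Tensor:
    b, c, h, w = x0_pred.shape
    flat = x0_pred.reshape(b, c * h * w).float()             # upcast: torch.quantile needs float32/float64
    s = torch.quantile(flat.abs(), ratio, dim=1)             # per-image percentile of |x0|
    s = torch.clamp(s, min=1.0, max=max_value).unsqueeze(1)  # s >= 1, so plain [-1, 1] clipping is the floor
    out = torch.clamp(flat, -s, s) / s                       # threshold to [-s, s], then divide by s
    return out.reshape(b, c, h, w).to(x0_pred.dtype)

x0 = torch.randn(2, 3, 8, 8) * 3                             # deliberately saturated toy prediction
print(dynamic_threshold(x0).abs().max())                     # always <= 1 after thresholding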
@@ -315,14 +338,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             )

         # 4. Clip or threshold "predicted x_0"
-        if self.config.clip_sample:
+        if self.config.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.config.clip_sample:
             pred_original_sample = pred_original_sample.clamp(
                 -self.config.clip_sample_range, self.config.clip_sample_range
             )

-        if self.config.thresholding:
-            pred_original_sample = self._threshold_sample(pred_original_sample)
-
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self._get_variance(timestep, prev_timestep)
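The behavioral point of the `step` hunk above is branch precedence: after the change, dynamic thresholding is applied whenever `thresholding` is enabled, and the static `clip_sample` clamp only runs otherwise. A hedged configuration sketch, assuming the constructor kwargs exposed by the scheduler at the time of this commit:

from diffusers import DDIMScheduler

# Dynamic thresholding path: predicted x_0 is rescaled per image; clip_sample is not consulted.
thresholded = DDIMScheduler(thresholding=True, dynamic_thresholding_ratio=0.995, sample_max_value=1.0)

# Static clipping path: predicted x_0 is clamped to [-clip_sample_range, clip_sample_range].
clipped = DDIMScheduler(thresholding=False, clip_sample=True, clip_sample_range=1.0)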
src/diffusers/schedulers/scheduling_ddpm.py

@@ -241,15 +241,38 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return variance

     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def step(
         self,
@@ -309,14 +332,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             )

         # 3. Clip or threshold "predicted x_0"
-        if self.config.clip_sample:
+        if self.config.thresholding:
+            pred_original_sample = self._threshold_sample(pred_original_sample)
+        elif self.config.clip_sample:
             pred_original_sample = pred_original_sample.clamp(
                 -self.config.clip_sample_range, self.config.clip_sample_range
             )

-        if self.config.thresholding:
-            pred_original_sample = self._threshold_sample(pred_original_sample)
-
         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
         pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
src/diffusers/schedulers/scheduling_deis_multistep.py

@@ -196,15 +196,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -236,11 +259,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             )

         if self.config.thresholding:
-            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-            orig_dtype = x0_pred.dtype
-            if orig_dtype not in [torch.float, torch.double]:
-                x0_pred = x0_pred.float()
-            x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+            x0_pred = self._threshold_sample(x0_pred)

         if self.config.algorithm_type == "deis":
             alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

@@ -207,15 +207,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -256,11 +279,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 )

             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
             return x0_pred
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
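A side effect of this hunk (and the matching ones in the other multistep schedulers) is that the dtype handling moves out of `convert_model_output` and into `_threshold_sample`, which now upcasts internally and restores the original dtype before returning. The sketch below, assuming a recent PyTorch build, shows why the upcast is needed at all:

import torch

x = torch.randn(2, 16, dtype=torch.float16)         # e.g. a half-precision x0 prediction
try:
    torch.quantile(x.abs(), 0.995, dim=1)           # quantile only accepts float32/float64 inputs
except RuntimeError as err:
    print("quantile on fp16 fails:", err)

s = torch.quantile(x.abs().float(), 0.995, dim=1)   # upcast, compute the percentile...
print(s.to(x.dtype))                                 # ...then cast back to the original dtype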
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

@@ -239,15 +239,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -288,11 +311,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 )

             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
             return x0_pred
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
src/diffusers/schedulers/scheduling_unipc_multistep.py

@@ -212,15 +212,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-        dynamic_max_val = (
-            sample.flatten(1)
-            .abs()
-            .quantile(self.config.dynamic_thresholding_ratio, dim=1)
-            .clamp_min(self.config.sample_max_value)
-            .view(-1, *([1] * (sample.ndim - 1)))
-        )
-        return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
+        """
+        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+        photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+        https://arxiv.org/abs/2205.11487
+        """
+        dtype = sample.dtype
+        batch_size, channels, height, width = sample.shape
+
+        if dtype not in (torch.float32, torch.float64):
+            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
+
+        # Flatten sample for doing quantile calculation along each image
+        sample = sample.reshape(batch_size, channels * height * width)
+
+        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
+
+        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+        s = torch.clamp(
+            s, min=1, max=self.config.sample_max_value
+        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
+        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.to(dtype)
+
+        return sample

     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
@@ -253,11 +276,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
                 )

             if self.config.thresholding:
-                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
-                orig_dtype = x0_pred.dtype
-                if orig_dtype not in [torch.float, torch.double]:
-                    x0_pred = x0_pred.float()
-                x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
+                x0_pred = self._threshold_sample(x0_pred)
             return x0_pred
         else:
             if self.config.prediction_type == "epsilon":
tests/schedulers/test_scheduler_dpm_multi.py

@@ -201,7 +201,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
         result_mean = torch.mean(torch.abs(sample))

-        assert abs(result_mean.item() - 0.6405) < 1e-3
+        assert abs(result_mean.item() - 1.1364) < 1e-3

     def test_full_loop_with_v_prediction(self):
         sample = self.full_loop(prediction_type="v_prediction")
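For reference, a rough sketch of what the test's `full_loop(...)` call exercises, assuming the public `DPMSolverMultistepScheduler` API; the stand-in model output below is random noise, so the resulting mean will not reproduce the 1.1364 expected by the real test:

import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5
)
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 32, 32)
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)                  # stand-in for a UNet epsilon prediction
    sample = scheduler.step(model_output, t, sample).prev_sample

print(torch.mean(torch.abs(sample)))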