Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
2cbdc586
Unverified
Commit
2cbdc586
authored
Apr 09, 2023
by
Will Berman
Committed by
GitHub
Apr 09, 2023
Browse files
dynamic threshold sampling bug fixes and docs (#3003)
dynamic threshold sampling bug fix and docs
parent
dcfa6e1d
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
206 additions
and
83 deletions
+206
-83
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+35
-13
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+35
-13
src/diffusers/schedulers/scheduling_deis_multistep.py
src/diffusers/schedulers/scheduling_deis_multistep.py
+33
-14
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+34
-14
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+34
-14
src/diffusers/schedulers/scheduling_unipc_multistep.py
src/diffusers/schedulers/scheduling_unipc_multistep.py
+34
-14
tests/schedulers/test_scheduler_dpm_multi.py
tests/schedulers/test_scheduler_dpm_multi.py
+1
-1
No files found.
src/diffusers/schedulers/scheduling_ddim.py
View file @
2cbdc586
...
@@ -201,15 +201,38 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -201,15 +201,38 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
"""
dynamic_max_val
=
(
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
sample
.
flatten
(
1
)
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
.
abs
()
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
.
quantile
(
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
.
clamp_min
(
self
.
config
.
sample_max_value
)
photorealism as well as better image-text alignment, especially when using very large guidance weights."
.
view
(
-
1
,
*
([
1
]
*
(
sample
.
ndim
-
1
)))
)
https://arxiv.org/abs/2205.11487
return
sample
.
clamp
(
-
dynamic_max_val
,
dynamic_max_val
)
/
dynamic_max_val
"""
dtype
=
sample
.
dtype
batch_size
,
channels
,
height
,
width
=
sample
.
shape
if
dtype
not
in
(
torch
.
float32
,
torch
.
float64
):
sample
=
sample
.
float
()
# upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample
=
sample
.
reshape
(
batch_size
,
channels
*
height
*
width
)
abs_sample
=
sample
.
abs
()
# "a certain percentile absolute pixel value"
s
=
torch
.
quantile
(
abs_sample
,
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
s
=
torch
.
clamp
(
s
,
min
=
1
,
max
=
self
.
config
.
sample_max_value
)
# When clamped to min=1, equivalent to standard clipping to [-1, 1]
s
=
s
.
unsqueeze
(
1
)
# (batch_size, 1) because clamp will broadcast along dim=0
sample
=
torch
.
clamp
(
sample
,
-
s
,
s
)
/
s
# "we threshold xt0 to the range [-s, s] and then divide by s"
sample
=
sample
.
reshape
(
batch_size
,
channels
,
height
,
width
)
sample
=
sample
.
to
(
dtype
)
return
sample
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
"""
...
@@ -315,14 +338,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -315,14 +338,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
)
)
# 4. Clip or threshold "predicted x_0"
# 4. Clip or threshold "predicted x_0"
if
self
.
config
.
clip_sample
:
if
self
.
config
.
thresholding
:
pred_original_sample
=
self
.
_threshold_sample
(
pred_original_sample
)
elif
self
.
config
.
clip_sample
:
pred_original_sample
=
pred_original_sample
.
clamp
(
pred_original_sample
=
pred_original_sample
.
clamp
(
-
self
.
config
.
clip_sample_range
,
self
.
config
.
clip_sample_range
-
self
.
config
.
clip_sample_range
,
self
.
config
.
clip_sample_range
)
)
if
self
.
config
.
thresholding
:
pred_original_sample
=
self
.
_threshold_sample
(
pred_original_sample
)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance
=
self
.
_get_variance
(
timestep
,
prev_timestep
)
variance
=
self
.
_get_variance
(
timestep
,
prev_timestep
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
2cbdc586
...
@@ -241,15 +241,38 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -241,15 +241,38 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
variance
return
variance
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
"""
dynamic_max_val
=
(
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
sample
.
flatten
(
1
)
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
.
abs
()
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
.
quantile
(
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
.
clamp_min
(
self
.
config
.
sample_max_value
)
photorealism as well as better image-text alignment, especially when using very large guidance weights."
.
view
(
-
1
,
*
([
1
]
*
(
sample
.
ndim
-
1
)))
)
https://arxiv.org/abs/2205.11487
return
sample
.
clamp
(
-
dynamic_max_val
,
dynamic_max_val
)
/
dynamic_max_val
"""
dtype
=
sample
.
dtype
batch_size
,
channels
,
height
,
width
=
sample
.
shape
if
dtype
not
in
(
torch
.
float32
,
torch
.
float64
):
sample
=
sample
.
float
()
# upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample
=
sample
.
reshape
(
batch_size
,
channels
*
height
*
width
)
abs_sample
=
sample
.
abs
()
# "a certain percentile absolute pixel value"
s
=
torch
.
quantile
(
abs_sample
,
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
s
=
torch
.
clamp
(
s
,
min
=
1
,
max
=
self
.
config
.
sample_max_value
)
# When clamped to min=1, equivalent to standard clipping to [-1, 1]
s
=
s
.
unsqueeze
(
1
)
# (batch_size, 1) because clamp will broadcast along dim=0
sample
=
torch
.
clamp
(
sample
,
-
s
,
s
)
/
s
# "we threshold xt0 to the range [-s, s] and then divide by s"
sample
=
sample
.
reshape
(
batch_size
,
channels
,
height
,
width
)
sample
=
sample
.
to
(
dtype
)
return
sample
def
step
(
def
step
(
self
,
self
,
...
@@ -309,14 +332,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -309,14 +332,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
)
)
# 3. Clip or threshold "predicted x_0"
# 3. Clip or threshold "predicted x_0"
if
self
.
config
.
clip_sample
:
if
self
.
config
.
thresholding
:
pred_original_sample
=
self
.
_threshold_sample
(
pred_original_sample
)
elif
self
.
config
.
clip_sample
:
pred_original_sample
=
pred_original_sample
.
clamp
(
pred_original_sample
=
pred_original_sample
.
clamp
(
-
self
.
config
.
clip_sample_range
,
self
.
config
.
clip_sample_range
-
self
.
config
.
clip_sample_range
,
self
.
config
.
clip_sample_range
)
)
if
self
.
config
.
thresholding
:
pred_original_sample
=
self
.
_threshold_sample
(
pred_original_sample
)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff
=
(
alpha_prod_t_prev
**
(
0.5
)
*
current_beta_t
)
/
beta_prod_t
pred_original_sample_coeff
=
(
alpha_prod_t_prev
**
(
0.5
)
*
current_beta_t
)
/
beta_prod_t
...
...
src/diffusers/schedulers/scheduling_deis_multistep.py
View file @
2cbdc586
...
@@ -196,15 +196,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -196,15 +196,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
"""
dynamic_max_val
=
(
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
sample
.
flatten
(
1
)
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
.
abs
()
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
.
quantile
(
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
.
clamp_min
(
self
.
config
.
sample_max_value
)
photorealism as well as better image-text alignment, especially when using very large guidance weights."
.
view
(
-
1
,
*
([
1
]
*
(
sample
.
ndim
-
1
)))
)
https://arxiv.org/abs/2205.11487
return
sample
.
clamp
(
-
dynamic_max_val
,
dynamic_max_val
)
/
dynamic_max_val
"""
dtype
=
sample
.
dtype
batch_size
,
channels
,
height
,
width
=
sample
.
shape
if
dtype
not
in
(
torch
.
float32
,
torch
.
float64
):
sample
=
sample
.
float
()
# upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample
=
sample
.
reshape
(
batch_size
,
channels
*
height
*
width
)
abs_sample
=
sample
.
abs
()
# "a certain percentile absolute pixel value"
s
=
torch
.
quantile
(
abs_sample
,
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
s
=
torch
.
clamp
(
s
,
min
=
1
,
max
=
self
.
config
.
sample_max_value
)
# When clamped to min=1, equivalent to standard clipping to [-1, 1]
s
=
s
.
unsqueeze
(
1
)
# (batch_size, 1) because clamp will broadcast along dim=0
sample
=
torch
.
clamp
(
sample
,
-
s
,
s
)
/
s
# "we threshold xt0 to the range [-s, s] and then divide by s"
sample
=
sample
.
reshape
(
batch_size
,
channels
,
height
,
width
)
sample
=
sample
.
to
(
dtype
)
return
sample
def
convert_model_output
(
def
convert_model_output
(
self
,
model_output
:
torch
.
FloatTensor
,
timestep
:
int
,
sample
:
torch
.
FloatTensor
self
,
model_output
:
torch
.
FloatTensor
,
timestep
:
int
,
sample
:
torch
.
FloatTensor
...
@@ -236,11 +259,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -236,11 +259,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
)
)
if
self
.
config
.
thresholding
:
if
self
.
config
.
thresholding
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
x0_pred
=
self
.
_threshold_sample
(
x0_pred
)
orig_dtype
=
x0_pred
.
dtype
if
orig_dtype
not
in
[
torch
.
float
,
torch
.
double
]:
x0_pred
=
x0_pred
.
float
()
x0_pred
=
self
.
_threshold_sample
(
x0_pred
).
type
(
orig_dtype
)
if
self
.
config
.
algorithm_type
==
"deis"
:
if
self
.
config
.
algorithm_type
==
"deis"
:
alpha_t
,
sigma_t
=
self
.
alpha_t
[
timestep
],
self
.
sigma_t
[
timestep
]
alpha_t
,
sigma_t
=
self
.
alpha_t
[
timestep
],
self
.
sigma_t
[
timestep
]
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
2cbdc586
...
@@ -207,15 +207,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -207,15 +207,38 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
"""
dynamic_max_val
=
(
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
sample
.
flatten
(
1
)
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
.
abs
()
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
.
quantile
(
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
.
clamp_min
(
self
.
config
.
sample_max_value
)
photorealism as well as better image-text alignment, especially when using very large guidance weights."
.
view
(
-
1
,
*
([
1
]
*
(
sample
.
ndim
-
1
)))
)
https://arxiv.org/abs/2205.11487
return
sample
.
clamp
(
-
dynamic_max_val
,
dynamic_max_val
)
/
dynamic_max_val
"""
dtype
=
sample
.
dtype
batch_size
,
channels
,
height
,
width
=
sample
.
shape
if
dtype
not
in
(
torch
.
float32
,
torch
.
float64
):
sample
=
sample
.
float
()
# upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample
=
sample
.
reshape
(
batch_size
,
channels
*
height
*
width
)
abs_sample
=
sample
.
abs
()
# "a certain percentile absolute pixel value"
s
=
torch
.
quantile
(
abs_sample
,
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
s
=
torch
.
clamp
(
s
,
min
=
1
,
max
=
self
.
config
.
sample_max_value
)
# When clamped to min=1, equivalent to standard clipping to [-1, 1]
s
=
s
.
unsqueeze
(
1
)
# (batch_size, 1) because clamp will broadcast along dim=0
sample
=
torch
.
clamp
(
sample
,
-
s
,
s
)
/
s
# "we threshold xt0 to the range [-s, s] and then divide by s"
sample
=
sample
.
reshape
(
batch_size
,
channels
,
height
,
width
)
sample
=
sample
.
to
(
dtype
)
return
sample
def
convert_model_output
(
def
convert_model_output
(
self
,
model_output
:
torch
.
FloatTensor
,
timestep
:
int
,
sample
:
torch
.
FloatTensor
self
,
model_output
:
torch
.
FloatTensor
,
timestep
:
int
,
sample
:
torch
.
FloatTensor
...
@@ -256,11 +279,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -256,11 +279,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
)
if
self
.
config
.
thresholding
:
if
self
.
config
.
thresholding
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
x0_pred
=
self
.
_threshold_sample
(
x0_pred
)
orig_dtype
=
x0_pred
.
dtype
if
orig_dtype
not
in
[
torch
.
float
,
torch
.
double
]:
x0_pred
=
x0_pred
.
float
()
x0_pred
=
self
.
_threshold_sample
(
x0_pred
).
type
(
orig_dtype
)
return
x0_pred
return
x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
# DPM-Solver needs to solve an integral of the noise prediction model.
elif
self
.
config
.
algorithm_type
==
"dpmsolver"
:
elif
self
.
config
.
algorithm_type
==
"dpmsolver"
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
View file @
2cbdc586
...
@@ -239,15 +239,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -239,15 +239,38 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
"""
dynamic_max_val
=
(
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
sample
.
flatten
(
1
)
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
.
abs
()
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
.
quantile
(
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
.
clamp_min
(
self
.
config
.
sample_max_value
)
photorealism as well as better image-text alignment, especially when using very large guidance weights."
.
view
(
-
1
,
*
([
1
]
*
(
sample
.
ndim
-
1
)))
)
https://arxiv.org/abs/2205.11487
return
sample
.
clamp
(
-
dynamic_max_val
,
dynamic_max_val
)
/
dynamic_max_val
"""
dtype
=
sample
.
dtype
batch_size
,
channels
,
height
,
width
=
sample
.
shape
if
dtype
not
in
(
torch
.
float32
,
torch
.
float64
):
sample
=
sample
.
float
()
# upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample
=
sample
.
reshape
(
batch_size
,
channels
*
height
*
width
)
abs_sample
=
sample
.
abs
()
# "a certain percentile absolute pixel value"
s
=
torch
.
quantile
(
abs_sample
,
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
s
=
torch
.
clamp
(
s
,
min
=
1
,
max
=
self
.
config
.
sample_max_value
)
# When clamped to min=1, equivalent to standard clipping to [-1, 1]
s
=
s
.
unsqueeze
(
1
)
# (batch_size, 1) because clamp will broadcast along dim=0
sample
=
torch
.
clamp
(
sample
,
-
s
,
s
)
/
s
# "we threshold xt0 to the range [-s, s] and then divide by s"
sample
=
sample
.
reshape
(
batch_size
,
channels
,
height
,
width
)
sample
=
sample
.
to
(
dtype
)
return
sample
def
convert_model_output
(
def
convert_model_output
(
self
,
model_output
:
torch
.
FloatTensor
,
timestep
:
int
,
sample
:
torch
.
FloatTensor
self
,
model_output
:
torch
.
FloatTensor
,
timestep
:
int
,
sample
:
torch
.
FloatTensor
...
@@ -288,11 +311,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -288,11 +311,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
)
)
if
self
.
config
.
thresholding
:
if
self
.
config
.
thresholding
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
x0_pred
=
self
.
_threshold_sample
(
x0_pred
)
orig_dtype
=
x0_pred
.
dtype
if
orig_dtype
not
in
[
torch
.
float
,
torch
.
double
]:
x0_pred
=
x0_pred
.
float
()
x0_pred
=
self
.
_threshold_sample
(
x0_pred
).
type
(
orig_dtype
)
return
x0_pred
return
x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
# DPM-Solver needs to solve an integral of the noise prediction model.
elif
self
.
config
.
algorithm_type
==
"dpmsolver"
:
elif
self
.
config
.
algorithm_type
==
"dpmsolver"
:
...
...
src/diffusers/schedulers/scheduling_unipc_multistep.py
View file @
2cbdc586
...
@@ -212,15 +212,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -212,15 +212,38 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
"""
dynamic_max_val
=
(
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
sample
.
flatten
(
1
)
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
.
abs
()
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
.
quantile
(
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
.
clamp_min
(
self
.
config
.
sample_max_value
)
photorealism as well as better image-text alignment, especially when using very large guidance weights."
.
view
(
-
1
,
*
([
1
]
*
(
sample
.
ndim
-
1
)))
)
https://arxiv.org/abs/2205.11487
return
sample
.
clamp
(
-
dynamic_max_val
,
dynamic_max_val
)
/
dynamic_max_val
"""
dtype
=
sample
.
dtype
batch_size
,
channels
,
height
,
width
=
sample
.
shape
if
dtype
not
in
(
torch
.
float32
,
torch
.
float64
):
sample
=
sample
.
float
()
# upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample
=
sample
.
reshape
(
batch_size
,
channels
*
height
*
width
)
abs_sample
=
sample
.
abs
()
# "a certain percentile absolute pixel value"
s
=
torch
.
quantile
(
abs_sample
,
self
.
config
.
dynamic_thresholding_ratio
,
dim
=
1
)
s
=
torch
.
clamp
(
s
,
min
=
1
,
max
=
self
.
config
.
sample_max_value
)
# When clamped to min=1, equivalent to standard clipping to [-1, 1]
s
=
s
.
unsqueeze
(
1
)
# (batch_size, 1) because clamp will broadcast along dim=0
sample
=
torch
.
clamp
(
sample
,
-
s
,
s
)
/
s
# "we threshold xt0 to the range [-s, s] and then divide by s"
sample
=
sample
.
reshape
(
batch_size
,
channels
,
height
,
width
)
sample
=
sample
.
to
(
dtype
)
return
sample
def
convert_model_output
(
def
convert_model_output
(
self
,
model_output
:
torch
.
FloatTensor
,
timestep
:
int
,
sample
:
torch
.
FloatTensor
self
,
model_output
:
torch
.
FloatTensor
,
timestep
:
int
,
sample
:
torch
.
FloatTensor
...
@@ -253,11 +276,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -253,11 +276,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
)
if
self
.
config
.
thresholding
:
if
self
.
config
.
thresholding
:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
x0_pred
=
self
.
_threshold_sample
(
x0_pred
)
orig_dtype
=
x0_pred
.
dtype
if
orig_dtype
not
in
[
torch
.
float
,
torch
.
double
]:
x0_pred
=
x0_pred
.
float
()
x0_pred
=
self
.
_threshold_sample
(
x0_pred
).
type
(
orig_dtype
)
return
x0_pred
return
x0_pred
else
:
else
:
if
self
.
config
.
prediction_type
==
"epsilon"
:
if
self
.
config
.
prediction_type
==
"epsilon"
:
...
...
tests/schedulers/test_scheduler_dpm_multi.py
View file @
2cbdc586
...
@@ -201,7 +201,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
...
@@ -201,7 +201,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
sample
=
self
.
full_loop
(
thresholding
=
True
,
dynamic_thresholding_ratio
=
0.87
,
sample_max_value
=
0.5
)
sample
=
self
.
full_loop
(
thresholding
=
True
,
dynamic_thresholding_ratio
=
0.87
,
sample_max_value
=
0.5
)
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
result_mean
=
torch
.
mean
(
torch
.
abs
(
sample
))
assert
abs
(
result_mean
.
item
()
-
0.6405
)
<
1e-3
assert
abs
(
result_mean
.
item
()
-
1.1364
)
<
1e-3
def
test_full_loop_with_v_prediction
(
self
):
def
test_full_loop_with_v_prediction
(
self
):
sample
=
self
.
full_loop
(
prediction_type
=
"v_prediction"
)
sample
=
self
.
full_loop
(
prediction_type
=
"v_prediction"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment