Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b2274ece
Commit
b2274ece
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
finish pndm scheduler
parent
de22d4cd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
233 additions
and
41 deletions
+233
-41
src/diffusers/pipelines/pipeline_pndm.py
src/diffusers/pipelines/pipeline_pndm.py
+3
-3
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+75
-31
tests/test_scheduler.py
tests/test_scheduler.py
+155
-7
No files found.
src/diffusers/pipelines/pipeline_pndm.py
View file @
b2274ece
...
...
@@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline):
)
image
=
image
.
to
(
torch_device
)
warmup
_time_steps
=
self
.
noise_scheduler
.
get_
warmup
_time_steps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
warmup
_time_steps
))):
t_orig
=
warmup
_time_steps
[
t
]
prk
_time_steps
=
self
.
noise_scheduler
.
get_
prk
_time_steps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
prk
_time_steps
))):
t_orig
=
prk
_time_steps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
image
=
self
.
noise_scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
b2274ece
...
...
@@ -56,15 +56,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at equations (12) and (13) and Algorithm 2.
# mainly at formulas (9), (12), (13) and Algorithm 2.
self
.
pndm_order
=
4
# running values
self
.
cur_residual
=
0
self
.
cur_sample
=
None
self
.
ets
=
[]
self
.
warmup
_time_steps
=
{}
self
.
prk
_time_steps
=
{}
self
.
time_steps
=
{}
self
.
set_prk_mode
()
def get_alpha(self, time_step):
    """Return the entry of `self.alphas` stored at index `time_step`."""
    alpha = self.alphas[time_step]
    return alpha
...
...
@@ -77,18 +78,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
self
.
one
return
self
.
alphas_cumprod
[
time_step
]
def
get_
warmup
_time_steps
(
self
,
num_inference_steps
):
if
num_inference_steps
in
self
.
warmup
_time_steps
:
return
self
.
warmup
_time_steps
[
num_inference_steps
]
def
get_
prk
_time_steps
(
self
,
num_inference_steps
):
if
num_inference_steps
in
self
.
prk
_time_steps
:
return
self
.
prk
_time_steps
[
num_inference_steps
]
inference_step_times
=
list
(
range
(
0
,
self
.
config
.
timesteps
,
self
.
config
.
timesteps
//
num_inference_steps
))
warmup
_time_steps
=
np
.
array
(
inference_step_times
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
prk
_time_steps
=
np
.
array
(
inference_step_times
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
self
.
warmup
_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
warmup
_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
prk
_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
prk
_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
return
self
.
warmup
_time_steps
[
num_inference_steps
]
return
self
.
prk
_time_steps
[
num_inference_steps
]
def
get_time_steps
(
self
,
num_inference_steps
):
if
num_inference_steps
in
self
.
time_steps
:
...
...
@@ -99,12 +100,25 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
self
.
time_steps
[
num_inference_steps
]
def set_prk_mode(self):
    """Select the "prk" stepping mode.

    Stores the string "prk" in `self.mode`; `step` reads this attribute to
    decide which step implementation to call.
    """
    setattr(self, "mode", "prk")
def set_plms_mode(self):
    """Select the "plms" stepping mode.

    Stores the string "plms" in `self.mode`; `step` reads this attribute to
    decide which step implementation to call.
    """
    setattr(self, "mode", "plms")
def step(self, *args, **kwargs):
    """Run one scheduler step using the implementation selected by `self.mode`.

    All positional and keyword arguments are forwarded unchanged.

    Returns:
        Whatever `step_prk` (mode "prk") or `step_plms` (mode "plms") returns.

    Raises:
        ValueError: if `self.mode` is neither "prk" nor "plms".
    """
    mode = self.mode
    if mode == "prk":
        handler = self.step_prk
    elif mode == "plms":
        handler = self.step_plms
    else:
        raise ValueError(f"mode {self.mode} does not exist.")
    return handler(*args, **kwargs)
def
step_prk
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
warmup_time_steps
=
self
.
get_warmup_time_steps
(
num_inference_steps
)
prk_time_steps
=
self
.
get_prk_time_steps
(
num_inference_steps
)
t_
prev
=
warmup
_time_steps
[
t
//
4
*
4
]
t_
next
=
warmup
_time_steps
[
min
(
t
+
1
,
len
(
warmup
_time_steps
)
-
1
)]
t_
orig
=
prk
_time_steps
[
t
//
4
*
4
]
t_
orig_prev
=
prk
_time_steps
[
min
(
t
+
1
,
len
(
prk
_time_steps
)
-
1
)]
if
t
%
4
==
0
:
self
.
cur_residual
+=
1
/
6
*
residual
...
...
@@ -118,33 +132,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual
=
self
.
cur_residual
+
1
/
6
*
residual
self
.
cur_residual
=
0
return
self
.
transfer
(
self
.
cur_sample
,
t_prev
,
t_next
,
residual
)
# cur_sample should not be `None`
cur_sample
=
self
.
cur_sample
if
self
.
cur_sample
is
not
None
else
sample
return
self
.
get_prev_sample
(
cur_sample
,
t_orig
,
t_orig_prev
,
residual
)
def
step_plms
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
if
len
(
self
.
ets
)
<
3
:
raise
ValueError
(
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
"for more information."
)
timesteps
=
self
.
get_time_steps
(
num_inference_steps
)
t_
prev
=
timesteps
[
t
]
t_
next
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
t_
orig
=
timesteps
[
t
]
t_
orig_prev
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
self
.
ets
.
append
(
residual
)
residual
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
return
self
.
transfer
(
sample
,
t_prev
,
t_next
,
residual
)
def
transfer
(
self
,
x
,
t
,
t_next
,
et
):
# TODO(Patrick): clean up to be compatible with numpy and give better names
alphas_cump
=
self
.
alphas_cumprod
.
to
(
x
.
device
)
at
=
alphas_cump
[
t
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
+
1
].
view
(
-
1
,
1
,
1
,
1
)
x_delta
=
(
at_next
-
at
)
*
(
(
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
)
x_next
=
x
+
x_delta
return
x_next
return
self
.
get_prev_sample
(
sample
,
t_orig
,
t_orig_prev
,
residual
)
def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
    """Compute x_(t−δ) from x_t using formula (9) of the PNDM paper.

    Paper: https://arxiv.org/pdf/2202.09778.pdf
    Formula (9) gives the difference x_(t−δ) - x_t, so x_t is added to both
    sides to obtain x_(t−δ) directly.

    Notation (<variable name> -> <name in paper>):
        alpha_cur   -> α_t
        alpha_prev  -> α_(t−δ)
        beta_cur    -> (1 - α_t)
        beta_prev   -> (1 - α_(t−δ))
        sample      -> x_t
        residual    -> e_θ(x_t, t)
        return value -> x_(t−δ)

    Args:
        sample: current sample x_t.
        t_orig: current timestep; `self.get_alpha_prod` is queried at `t_orig + 1`.
        t_orig_prev: timestep being stepped to; queried at `t_orig_prev + 1`.
        residual: model output e_θ(x_t, t).

    Returns:
        The sample at the previous timestep, x_(t−δ).
    """
    alpha_cur = self.get_alpha_prod(t_orig + 1)
    alpha_prev = self.get_alpha_prod(t_orig_prev + 1)
    beta_cur = 1 - alpha_cur
    beta_prev = 1 - alpha_prev

    # Coefficient of x_t: the x_t coefficient of formula (9) plus 1, because
    # (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1
    #     = sqrt(α_(t−δ)) / sqrt(α_t)
    sample_scale = (alpha_prev / alpha_cur) ** (0.5)

    # Denominator of the e_θ(x_t, t) term in formula (9).
    denom_first = alpha_cur * beta_prev ** (0.5)
    denom_second = (alpha_cur * beta_cur * alpha_prev) ** (0.5)
    residual_denom = denom_first + denom_second

    # Full formula (9), with x_t moved to the right-hand side.
    residual_term = (alpha_prev - alpha_cur) * residual / residual_denom
    return sample_scale * sample - residual_term
def __len__(self):
    """Return `self.config.timesteps`."""
    config = self.config
    return config.timesteps
tests/test_scheduler.py
View file @
b2274ece
...
...
@@ -20,7 +20,7 @@ import unittest
import
numpy
as
np
import
torch
from
diffusers
import
DDIMScheduler
,
DDPMScheduler
from
diffusers
import
DDIMScheduler
,
DDPMScheduler
,
PNDMScheduler
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
...
...
@@ -90,10 +90,10 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs
.
update
(
forward_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
image
=
self
.
dummy_image
residual
=
0.1
*
image
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
...
@@ -159,7 +159,7 @@ class SchedulerCommonTest(unittest.TestCase):
output
=
scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
image_pt
,
1
,
**
kwargs
)
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-
5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-
4
,
"Scheduler outputs are not identical"
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
...
...
@@ -237,8 +237,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
result_sum
=
np
.
sum
(
np
.
abs
(
image
))
result_mean
=
np
.
mean
(
np
.
abs
(
image
))
assert
result_sum
.
item
()
-
732.9947
<
1e-
3
assert
result_mean
.
item
()
-
0.9544
<
1e-3
assert
abs
(
result_sum
.
item
()
-
732.9947
)
<
1e-
2
assert
abs
(
result_mean
.
item
()
-
0.9544
)
<
1e-3
class
DDIMSchedulerTest
(
SchedulerCommonTest
):
...
...
@@ -325,5 +325,153 @@ class DDIMSchedulerTest(SchedulerCommonTest):
result_sum
=
np
.
sum
(
np
.
abs
(
image
))
result_mean
=
np
.
mean
(
np
.
abs
(
image
))
assert
result_sum
.
item
()
-
270.6214
<
1e-3
assert
result_mean
.
item
()
-
0.3524
<
1e-3
assert
abs
(
result_sum
.
item
()
-
270.6214
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.3524
)
<
1e-3
class PNDMSchedulerTest(SchedulerCommonTest):
    """Tests for `PNDMScheduler` in both its "prk" and "plms" stepping modes.

    Helpers and tests suffixed `_pmls` exercise the "plms" mode; the suffix
    spelling is kept as-is so existing test ids do not change.
    """

    scheduler_classes = (PNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, overridden by `kwargs`."""
        config = {
            "timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }
        config.update(**kwargs)
        return config

    def _check_plms_save_load(self, time_step, scheduler_config, step_kwargs):
        """Assert a saved and reloaded scheduler gives the same "plms" step output.

        Args:
            time_step: index passed to `scheduler.step`.
            scheduler_config: keyword arguments used to build the scheduler.
            step_kwargs: extra keyword arguments forwarded to `scheduler.step`.
        """
        image = self.dummy_image
        residual = 0.1 * image
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

        for scheduler_class in self.scheduler_classes:
            scheduler = scheduler_class(**scheduler_config)
            # "plms" mode needs past residuals; copy so each scheduler owns its list
            scheduler.ets = dummy_past_residuals[:]
            scheduler.set_plms_mode()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)

            # running state is set by hand on the reloaded instance, mirroring the original
            new_scheduler.ets = dummy_past_residuals[:]
            new_scheduler.set_plms_mode()

            output = scheduler.step(residual, image, time_step, **step_kwargs)
            new_output = new_scheduler.step(residual, image, time_step, **step_kwargs)

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_configs_pmls(self, time_step=0, **config):
        """Check the "plms" save/load round trip for a scheduler built with `config` overrides."""
        kwargs = dict(self.forward_default_kwargs)
        self._check_plms_save_load(time_step, self.get_scheduler_config(**config), kwargs)

    def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
        """Check the "plms" save/load round trip with `forward_kwargs` overriding the step kwargs."""
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)
        self._check_plms_save_load(time_step, self.get_scheduler_config(), kwargs)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(timesteps=timesteps)

    def test_timesteps_pmls(self):
        for timesteps in [100, 1000]:
            self.check_over_configs_pmls(timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_betas_pmls(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
            self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_schedules_pmls(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            # use the "plms" helper so this test covers the plms path
            self.check_over_configs_pmls(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_time_indices_pmls(self):
        for t in [1, 5, 10]:
            self.check_over_forward_pmls(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_inference_steps_pmls(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)

    def test_inference_pmls_no_past_residuals(self):
        """A "plms" step with no stored past residuals must raise `ValueError`."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_plms_mode()

        # only the step call is expected to raise; setup errors should surface as-is
        with self.assertRaises(ValueError):
            scheduler.step(self.dummy_image, self.dummy_image, 1, 50)

    def test_full_loop_no_noise(self):
        """Run the prk steps followed by the plms steps and compare against reference statistics."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        image = self.dummy_image_deter

        prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
        for t, t_orig in enumerate(prk_time_steps):
            residual = model(image, t_orig)
            image = scheduler.step_prk(residual, image, t, num_inference_steps)

        timesteps = scheduler.get_time_steps(num_inference_steps)
        for t, t_orig in enumerate(timesteps):
            residual = model(image, t_orig)
            image = scheduler.step_plms(residual, image, t, num_inference_steps)

        result_sum = np.sum(np.abs(image))
        result_mean = np.mean(np.abs(image))

        assert abs(result_sum.item() - 199.1169) < 1e-2
        assert abs(result_mean.item() - 0.2593) < 1e-3
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment