Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
194ed794
Unverified
Commit
194ed794
authored
Aug 16, 2022
by
Patrick von Platen
Committed by
GitHub
Aug 16, 2022
Browse files
[PNDM] Stable diffusion (#186)
* [PNDM] Stable diffusino * finish
parent
051b3463
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
13 deletions
+43
-13
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+41
-12
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+2
-1
No files found.
src/diffusers/schedulers/scheduling_pndm.py
View file @
194ed794
...
...
@@ -56,6 +56,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
tensor_format
=
"pt"
,
skip_prk_steps
=
False
,
):
if
beta_schedule
==
"linear"
:
...
...
@@ -88,6 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self
.
num_inference_steps
=
None
self
.
_timesteps
=
np
.
arange
(
0
,
num_train_timesteps
)[::
-
1
].
copy
()
self
.
_offset
=
0
self
.
prk_timesteps
=
None
self
.
plms_timesteps
=
None
self
.
timesteps
=
None
...
...
@@ -95,17 +97,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
tensor_format
=
tensor_format
self
.
set_format
(
tensor_format
=
tensor_format
)
def
set_timesteps
(
self
,
num_inference_steps
):
def
set_timesteps
(
self
,
num_inference_steps
,
offset
=
0
):
self
.
num_inference_steps
=
num_inference_steps
self
.
_timesteps
=
list
(
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
)
self
.
_offset
=
offset
self
.
_timesteps
=
[
t
+
self
.
_offset
for
t
in
self
.
_timesteps
]
if
self
.
config
.
skip_prk_steps
:
# for some models like stable diffusion the prk steps can/should be skipped to
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
self
.
prk_timesteps
=
[]
self
.
plms_timesteps
=
list
(
reversed
(
self
.
_timesteps
[:
-
1
]
+
self
.
_timesteps
[
-
2
:
-
1
]
+
self
.
_timesteps
[
-
1
:]))
else
:
prk_timesteps
=
np
.
array
(
self
.
_timesteps
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
self
.
prk_timesteps
=
list
(
reversed
(
prk_timesteps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
_timesteps
[:
-
3
]))
prk_timesteps
=
np
.
array
(
self
.
_timesteps
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
self
.
prk_timesteps
=
list
(
reversed
(
prk_timesteps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
_timesteps
[:
-
3
]))
self
.
timesteps
=
self
.
prk_timesteps
+
self
.
plms_timesteps
self
.
counter
=
0
...
...
@@ -117,7 +129,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
):
if
self
.
counter
<
len
(
self
.
prk_timesteps
):
if
self
.
counter
<
len
(
self
.
prk_timesteps
)
and
not
self
.
config
.
skip_prk_steps
:
return
self
.
step_prk
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
else
:
return
self
.
step_plms
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
...
...
@@ -166,7 +178,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
"""
if
len
(
self
.
ets
)
<
3
:
if
not
self
.
config
.
skip_prk_steps
and
len
(
self
.
ets
)
<
3
:
raise
ValueError
(
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
...
...
@@ -175,9 +187,26 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
)
prev_timestep
=
max
(
timestep
-
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
,
0
)
self
.
ets
.
append
(
model_output
)
model_output
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
if
self
.
counter
!=
1
:
self
.
ets
.
append
(
model_output
)
else
:
prev_timestep
=
timestep
timestep
=
timestep
+
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
if
len
(
self
.
ets
)
==
1
and
self
.
counter
==
0
:
model_output
=
model_output
self
.
cur_sample
=
sample
elif
len
(
self
.
ets
)
==
1
and
self
.
counter
==
1
:
model_output
=
(
model_output
+
self
.
ets
[
-
1
])
/
2
sample
=
self
.
cur_sample
self
.
cur_sample
=
None
elif
len
(
self
.
ets
)
==
2
:
model_output
=
(
3
*
self
.
ets
[
-
1
]
-
self
.
ets
[
-
2
])
/
2
elif
len
(
self
.
ets
)
==
3
:
model_output
=
(
23
*
self
.
ets
[
-
1
]
-
16
*
self
.
ets
[
-
2
]
+
5
*
self
.
ets
[
-
3
])
/
12
else
:
model_output
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
prev_sample
=
self
.
_get_prev_sample
(
sample
,
timestep
,
prev_timestep
,
model_output
)
self
.
counter
+=
1
...
...
@@ -197,8 +226,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
+
1
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
timestep_prev
+
1
]
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
+
1
-
self
.
_offset
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
timestep_prev
+
1
-
self
.
_offset
]
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
...
...
tests/test_modeling_utils.py
View file @
194ed794
...
...
@@ -843,6 +843,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
unittest
.
skipIf
(
torch_device
==
"cpu"
,
"Stable diffusion is suppused to run on GPU"
)
def
test_stable_diffusion
(
self
):
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1-diffusers"
)
prompt
=
"A painting of a squirrel eating a burger"
...
...
@@ -857,7 +858,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
512
,
512
,
3
)
expected_slice
=
np
.
array
([
0.8
9
8
,
0.91
94
,
0.91
,
0.89
55
,
0.9
15
,
0.91
9
,
0.9
233
,
0.9
307
,
0.88
87
])
expected_slice
=
np
.
array
([
0.88
87
,
0.91
5
,
0.91
,
0.89
4
,
0.9
09
,
0.91
2
,
0.9
19
,
0.9
25
,
0.88
3
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
@
slow
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment