Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
194ed794
Unverified
Commit
194ed794
authored
Aug 16, 2022
by
Patrick von Platen
Committed by
GitHub
Aug 16, 2022
Browse files
[PNDM] Stable diffusion (#186)
* [PNDM] Stable diffusino * finish
parent
051b3463
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
13 deletions
+43
-13
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+41
-12
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+2
-1
No files found.
src/diffusers/schedulers/scheduling_pndm.py
View file @
194ed794
...
@@ -56,6 +56,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -56,6 +56,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_end
=
0.02
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
beta_schedule
=
"linear"
,
tensor_format
=
"pt"
,
tensor_format
=
"pt"
,
skip_prk_steps
=
False
,
):
):
if
beta_schedule
==
"linear"
:
if
beta_schedule
==
"linear"
:
...
@@ -88,6 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -88,6 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# setable values
# setable values
self
.
num_inference_steps
=
None
self
.
num_inference_steps
=
None
self
.
_timesteps
=
np
.
arange
(
0
,
num_train_timesteps
)[::
-
1
].
copy
()
self
.
_timesteps
=
np
.
arange
(
0
,
num_train_timesteps
)[::
-
1
].
copy
()
self
.
_offset
=
0
self
.
prk_timesteps
=
None
self
.
prk_timesteps
=
None
self
.
plms_timesteps
=
None
self
.
plms_timesteps
=
None
self
.
timesteps
=
None
self
.
timesteps
=
None
...
@@ -95,17 +97,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -95,17 +97,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
tensor_format
=
tensor_format
self
.
tensor_format
=
tensor_format
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
set_format
(
tensor_format
=
tensor_format
)
def
set_timesteps
(
self
,
num_inference_steps
):
def
set_timesteps
(
self
,
num_inference_steps
,
offset
=
0
):
self
.
num_inference_steps
=
num_inference_steps
self
.
num_inference_steps
=
num_inference_steps
self
.
_timesteps
=
list
(
self
.
_timesteps
=
list
(
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
)
)
self
.
_offset
=
offset
self
.
_timesteps
=
[
t
+
self
.
_offset
for
t
in
self
.
_timesteps
]
if
self
.
config
.
skip_prk_steps
:
# for some models like stable diffusion the prk steps can/should be skipped to
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
self
.
prk_timesteps
=
[]
self
.
plms_timesteps
=
list
(
reversed
(
self
.
_timesteps
[:
-
1
]
+
self
.
_timesteps
[
-
2
:
-
1
]
+
self
.
_timesteps
[
-
1
:]))
else
:
prk_timesteps
=
np
.
array
(
self
.
_timesteps
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
self
.
prk_timesteps
=
list
(
reversed
(
prk_timesteps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
_timesteps
[:
-
3
]))
prk_timesteps
=
np
.
array
(
self
.
_timesteps
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
self
.
prk_timesteps
=
list
(
reversed
(
prk_timesteps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
_timesteps
[:
-
3
]))
self
.
timesteps
=
self
.
prk_timesteps
+
self
.
plms_timesteps
self
.
timesteps
=
self
.
prk_timesteps
+
self
.
plms_timesteps
self
.
counter
=
0
self
.
counter
=
0
...
@@ -117,7 +129,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -117,7 +129,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
timestep
:
int
,
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
):
):
if
self
.
counter
<
len
(
self
.
prk_timesteps
):
if
self
.
counter
<
len
(
self
.
prk_timesteps
)
and
not
self
.
config
.
skip_prk_steps
:
return
self
.
step_prk
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
return
self
.
step_prk
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
else
:
else
:
return
self
.
step_plms
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
return
self
.
step_plms
(
model_output
=
model_output
,
timestep
=
timestep
,
sample
=
sample
)
...
@@ -166,7 +178,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -166,7 +178,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
times to approximate the solution.
"""
"""
if
len
(
self
.
ets
)
<
3
:
if
not
self
.
config
.
skip_prk_steps
and
len
(
self
.
ets
)
<
3
:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
"in 'prk' mode for at least 12 iterations "
...
@@ -175,9 +187,26 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -175,9 +187,26 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
)
)
prev_timestep
=
max
(
timestep
-
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
,
0
)
prev_timestep
=
max
(
timestep
-
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
,
0
)
self
.
ets
.
append
(
model_output
)
model_output
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
if
self
.
counter
!=
1
:
self
.
ets
.
append
(
model_output
)
else
:
prev_timestep
=
timestep
timestep
=
timestep
+
self
.
config
.
num_train_timesteps
//
self
.
num_inference_steps
if
len
(
self
.
ets
)
==
1
and
self
.
counter
==
0
:
model_output
=
model_output
self
.
cur_sample
=
sample
elif
len
(
self
.
ets
)
==
1
and
self
.
counter
==
1
:
model_output
=
(
model_output
+
self
.
ets
[
-
1
])
/
2
sample
=
self
.
cur_sample
self
.
cur_sample
=
None
elif
len
(
self
.
ets
)
==
2
:
model_output
=
(
3
*
self
.
ets
[
-
1
]
-
self
.
ets
[
-
2
])
/
2
elif
len
(
self
.
ets
)
==
3
:
model_output
=
(
23
*
self
.
ets
[
-
1
]
-
16
*
self
.
ets
[
-
2
]
+
5
*
self
.
ets
[
-
3
])
/
12
else
:
model_output
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
prev_sample
=
self
.
_get_prev_sample
(
sample
,
timestep
,
prev_timestep
,
model_output
)
prev_sample
=
self
.
_get_prev_sample
(
sample
,
timestep
,
prev_timestep
,
model_output
)
self
.
counter
+=
1
self
.
counter
+=
1
...
@@ -197,8 +226,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -197,8 +226,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# sample -> x_t
# model_output -> e_θ(x_t, t)
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
# prev_sample -> x_(t−δ)
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
+
1
]
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
+
1
-
self
.
_offset
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
timestep_prev
+
1
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
timestep_prev
+
1
-
self
.
_offset
]
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
...
...
tests/test_modeling_utils.py
View file @
194ed794
...
@@ -843,6 +843,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -843,6 +843,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
@
unittest
.
skipIf
(
torch_device
==
"cpu"
,
"Stable diffusion is suppused to run on GPU"
)
@
unittest
.
skipIf
(
torch_device
==
"cpu"
,
"Stable diffusion is suppused to run on GPU"
)
def
test_stable_diffusion
(
self
):
def
test_stable_diffusion
(
self
):
# make sure here that pndm scheduler skips prk
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1-diffusers"
)
sd_pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-1-diffusers"
)
prompt
=
"A painting of a squirrel eating a burger"
prompt
=
"A painting of a squirrel eating a burger"
...
@@ -857,7 +858,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -857,7 +858,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
512
,
512
,
3
)
assert
image
.
shape
==
(
1
,
512
,
512
,
3
)
expected_slice
=
np
.
array
([
0.8
9
8
,
0.91
94
,
0.91
,
0.89
55
,
0.9
15
,
0.91
9
,
0.9
233
,
0.9
307
,
0.88
87
])
expected_slice
=
np
.
array
([
0.88
87
,
0.91
5
,
0.91
,
0.89
4
,
0.9
09
,
0.91
2
,
0.9
19
,
0.9
25
,
0.88
3
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
@
slow
@
slow
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment