Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
da1f920e
Commit
da1f920e
authored
Jun 14, 2022
by
Patrick von Platen
Browse files
finalize pndm
parent
9b7e6f49
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
11 deletions
+12
-11
src/diffusers/pipelines/pipeline_pndm.py
src/diffusers/pipelines/pipeline_pndm.py
+4
-7
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+8
-4
No files found.
src/diffusers/pipelines/pipeline_pndm.py
View file @
da1f920e
...
...
@@ -28,7 +28,8 @@ class PNDM(DiffusionPipeline):
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
):
# eta corresponds to η in paper and should be between [0, 1]
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
...
@@ -42,21 +43,17 @@ class PNDM(DiffusionPipeline):
image
=
image
.
to
(
torch_device
)
warmup_time_steps
=
self
.
noise_scheduler
.
get_warmup_time_steps
(
num_inference_steps
)
prev_image
=
image
for
t
in
tqdm
.
tqdm
(
range
(
len
(
warmup_time_steps
))):
t_orig
=
warmup_time_steps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
if
t
%
4
==
0
:
prev_image
=
image
image
=
self
.
noise_scheduler
.
step_warm_up
(
residual
,
prev_image
,
t
,
num_inference_steps
)
image
=
self
.
noise_scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
timesteps
=
self
.
noise_scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
timesteps
))):
t_orig
=
timesteps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
)
image
=
self
.
noise_scheduler
.
step
_plms
(
residual
,
image
,
t
,
num_inference_steps
)
return
image
src/diffusers/schedulers/scheduling_pndm.py
View file @
da1f920e
...
...
@@ -55,11 +55,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
set_format
(
tensor_format
=
tensor_format
)
# for now we only support F-PNDM, i.e. the runge-kutta method
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at equations (12) and (13) and the Algorithm 2.
self
.
pndm_order
=
4
# running values
self
.
cur_residual
=
0
self
.
cur_image
=
None
self
.
ets
=
[]
self
.
warmup_time_steps
=
{}
self
.
time_steps
=
{}
...
...
@@ -95,7 +98,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
self
.
time_steps
[
num_inference_steps
]
def
step_
warm_up
(
self
,
residual
,
image
,
t
,
num_inference_steps
):
def
step_
prk
(
self
,
residual
,
image
,
t
,
num_inference_steps
):
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
warmup_time_steps
=
self
.
get_warmup_time_steps
(
num_inference_steps
)
...
...
@@ -105,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if
t
%
4
==
0
:
self
.
cur_residual
+=
1
/
6
*
residual
self
.
ets
.
append
(
residual
)
self
.
cur_image
=
image
elif
(
t
-
1
)
%
4
==
0
:
self
.
cur_residual
+=
1
/
3
*
residual
elif
(
t
-
2
)
%
4
==
0
:
...
...
@@ -113,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual
=
self
.
cur_residual
+
1
/
6
*
residual
self
.
cur_residual
=
0
return
self
.
transfer
(
image
,
t_prev
,
t_next
,
residual
)
return
self
.
transfer
(
self
.
cur_
image
,
t_prev
,
t_next
,
residual
)
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
):
def
step
_plms
(
self
,
residual
,
image
,
t
,
num_inference_steps
):
timesteps
=
self
.
get_time_steps
(
num_inference_steps
)
t_prev
=
timesteps
[
t
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment