Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
559b8cbf
Commit
559b8cbf
authored
Jun 14, 2022
by
Patrick von Platen
Browse files
finish pndm
parent
7d8bf1a9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
82 deletions
+66
-82
src/diffusers/pipelines/pipeline_pndm.py
src/diffusers/pipelines/pipeline_pndm.py
+13
-49
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+53
-33
No files found.
src/diffusers/pipelines/pipeline_pndm.py
View file @
559b8cbf
...
@@ -32,9 +32,6 @@ class PNDM(DiffusionPipeline):
...
@@ -32,9 +32,6 @@ class PNDM(DiffusionPipeline):
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
num_trained_timesteps
=
self
.
noise_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
# Sample gaussian noise to begin loop
# Sample gaussian noise to begin loop
...
@@ -44,55 +41,22 @@ class PNDM(DiffusionPipeline):
...
@@ -44,55 +41,22 @@ class PNDM(DiffusionPipeline):
)
)
image
=
image
.
to
(
torch_device
)
image
=
image
.
to
(
torch_device
)
seq
=
list
(
inference_step
_time
s
)
warmup_time_steps
=
self
.
noise_scheduler
.
get_warmup_time_steps
(
num_
inference_steps
)
seq_next
=
[
-
1
]
+
list
(
seq
[:
-
1
])
prev_image
=
image
model
=
self
.
unet
for
t
in
tqdm
.
tqdm
(
range
(
len
(
warmup_time_steps
))):
t_orig
=
warmup_time_steps
[
t
]
warmup_time_steps
=
list
(
reversed
([(
t
+
5
)
//
10
*
10
for
t
in
range
(
seq
[
-
4
],
seq
[
-
1
],
5
)])
)
residual
=
self
.
unet
(
image
,
t_orig
)
cur_residual
=
0
if
t
%
4
=
=
0
:
prev_image
=
image
prev_image
=
image
ets
=
[]
for
i
in
range
(
len
(
warmup_time_steps
)):
t
=
warmup_time_steps
[
i
]
*
torch
.
ones
(
image
.
shape
[
0
])
t_next
=
(
warmup_time_steps
[
i
+
1
]
if
i
<
len
(
warmup_time_steps
)
-
1
else
warmup_time_steps
[
-
1
])
*
torch
.
ones
(
image
.
shape
[
0
])
residual
=
model
(
image
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
image
=
self
.
noise_scheduler
.
step_warm_up
(
residual
,
prev_image
,
t
,
num_inference_steps
)
residual
=
residual
.
to
(
"cpu"
)
if
i
%
4
==
0
:
timesteps
=
self
.
noise_scheduler
.
get_time_steps
(
num_inference_steps
)
cur_residual
+=
1
/
6
*
residual
for
t
in
tqdm
.
tqdm
(
range
(
len
(
timesteps
))):
ets
.
append
(
residual
)
t_orig
=
timesteps
[
t
]
prev_image
=
image
residual
=
self
.
unet
(
image
,
t_orig
)
elif
(
i
-
1
)
%
4
==
0
:
cur_residual
+=
1
/
3
*
residual
image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
)
elif
(
i
-
2
)
%
4
==
0
:
cur_residual
+=
1
/
3
*
residual
elif
(
i
-
3
)
%
4
==
0
:
cur_residual
+=
1
/
6
*
residual
residual
=
cur_residual
cur_residual
=
0
image
=
image
.
to
(
"cpu"
)
t_2
=
warmup_time_steps
[
4
*
(
i
//
4
)]
*
torch
.
ones
(
image
.
shape
[
0
])
image
=
self
.
noise_scheduler
.
transfer
(
prev_image
.
to
(
"cpu"
),
t_2
,
t_next
,
residual
)
step_idx
=
len
(
seq
)
-
4
while
step_idx
>=
0
:
i
=
seq
[
step_idx
]
j
=
seq_next
[
step_idx
]
t
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
i
)
t_next
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
j
)
residual
=
model
(
image
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
residual
=
residual
.
to
(
"cpu"
)
ets
.
append
(
residual
)
residual
=
(
1
/
24
)
*
(
55
*
ets
[
-
1
]
-
59
*
ets
[
-
2
]
+
37
*
ets
[
-
3
]
-
9
*
ets
[
-
4
])
img_next
=
self
.
noise_scheduler
.
transfer
(
image
.
to
(
"cpu"
),
t
,
t_next
,
residual
)
image
=
img_next
step_idx
=
step_idx
-
1
return
image
return
image
src/diffusers/schedulers/scheduling_pndm.py
View file @
559b8cbf
...
@@ -55,6 +55,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -55,6 +55,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
set_format
(
tensor_format
=
tensor_format
)
# hardcode for now
self
.
pndm_order
=
4
self
.
cur_residual
=
0
# running values
self
.
ets
=
[]
self
.
warmup_time_steps
=
{}
self
.
time_steps
=
{}
# self.register_buffer("betas", betas.to(torch.float32))
# self.register_buffer("betas", betas.to(torch.float32))
# self.register_buffer("alphas", alphas.to(torch.float32))
# self.register_buffer("alphas", alphas.to(torch.float32))
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
...
@@ -83,51 +92,62 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -83,51 +92,62 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
self
.
one
return
self
.
one
return
self
.
alphas_cumprod
[
time_step
]
return
self
.
alphas_cumprod
[
time_step
]
def
step
(
self
,
img
,
t_start
,
t_end
,
model
,
ets
):
def
get_warmup_time_steps
(
self
,
num_inference_steps
):
# img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
if
num_inference_steps
in
self
.
warmup_time_steps
:
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
return
self
.
warmup_time_steps
[
num_inference_steps
]
t_next
,
t
=
t_start
,
t_end
noise_
=
model
(
img
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
inference_step_times
=
list
(
range
(
0
,
self
.
timesteps
,
self
.
timesteps
//
num_inference_steps
))
noise_
=
noise_
.
to
(
"cpu"
)
t_list
=
[
t
,
(
t
+
t_next
)
/
2
,
t_next
]
warmup_time_steps
=
np
.
array
(
inference_step_times
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
if
len
(
ets
)
>
2
:
self
.
warmup_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
warmup_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
ets
.
append
(
noise_
)
noise
=
(
1
/
24
)
*
(
55
*
ets
[
-
1
]
-
59
*
ets
[
-
2
]
+
37
*
ets
[
-
3
]
-
9
*
ets
[
-
4
])
return
self
.
warmup_time_steps
[
num_inference_steps
]
else
:
noise
=
self
.
runge_kutta
(
img
,
t_list
,
model
,
ets
,
noise_
)
img_next
=
self
.
transfer
(
img
.
to
(
"cpu"
),
t
,
t_next
,
noise
)
def
get_time_steps
(
self
,
num_inference_steps
):
return
img_next
,
ets
if
num_inference_steps
in
self
.
time_steps
:
return
self
.
time_steps
[
num_inference_steps
]
def
runge_kutta
(
self
,
x
,
t_list
,
model
,
ets
,
noise_
):
inference_step_times
=
list
(
range
(
0
,
self
.
timesteps
,
self
.
timesteps
//
num_inference_steps
))
model
=
model
.
to
(
"cuda"
)
self
.
time_steps
[
num_inference_steps
]
=
list
(
reversed
(
inference_step_times
[:
-
3
]))
x
=
x
.
to
(
"cpu"
)
e_1
=
noise_
return
self
.
time_steps
[
num_inference_steps
]
ets
.
append
(
e_1
)
x_2
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_1
)
e_2
=
model
(
x_2
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
def
step_warm_up
(
self
,
residual
,
image
,
t
,
num_inference_steps
):
e_2
=
e_2
.
to
(
"cpu"
)
warmup_time_steps
=
self
.
get_warmup_time_steps
(
num_inference_steps
)
x_3
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_2
)
e_3
=
model
(
x_3
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
t_prev
=
warmup_time_steps
[
t
//
4
*
4
]
e_3
=
e_3
.
to
(
"cpu"
)
t_next
=
warmup_time_steps
[
min
(
t
+
1
,
len
(
warmup_time_steps
)
-
1
)]
x_4
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
2
],
e_3
)
e_4
=
model
(
x_4
.
to
(
"cuda"
),
t_list
[
2
].
to
(
"cuda"
))
if
t
%
4
==
0
:
e_4
=
e_4
.
to
(
"cpu"
)
self
.
cur_residual
+=
1
/
6
*
residual
self
.
ets
.
append
(
residual
)
elif
(
t
-
1
)
%
4
==
0
:
self
.
cur_residual
+=
1
/
3
*
residual
elif
(
t
-
2
)
%
4
==
0
:
self
.
cur_residual
+=
1
/
3
*
residual
elif
(
t
-
3
)
%
4
==
0
:
residual
=
self
.
cur_residual
+
1
/
6
*
residual
self
.
cur_residual
=
0
et
=
(
1
/
6
)
*
(
e_1
+
2
*
e_2
+
2
*
e_3
+
e_4
)
r
et
urn
self
.
transfer
(
image
,
t_prev
,
t_next
,
residual
)
return
et
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
):
timesteps
=
self
.
get_time_steps
(
num_inference_steps
)
t_prev
=
timesteps
[
t
]
t_next
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
self
.
ets
.
append
(
residual
)
residual
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
return
self
.
transfer
(
image
,
t_prev
,
t_next
,
residual
)
def
transfer
(
self
,
x
,
t
,
t_next
,
et
):
def
transfer
(
self
,
x
,
t
,
t_next
,
et
):
alphas_cump
=
self
.
alphas_cumprod
# TODO(Patrick): clean up to be compatible with numpy and give better names
at
=
alphas_cump
[
t
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
alphas_cump
=
self
.
alphas_cumprod
.
to
(
x
.
device
)
at
=
alphas_cump
[
t
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
+
1
].
view
(
-
1
,
1
,
1
,
1
)
x_delta
=
(
at_next
-
at
)
*
((
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
)
x_delta
=
(
at_next
-
at
)
*
((
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment