Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
559b8cbf
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3228eb16091b72c316a90dc8acb0a8638e81fd34"
Commit
559b8cbf
authored
Jun 14, 2022
by
Patrick von Platen
Browse files
finish pndm
parent
7d8bf1a9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
82 deletions
+66
-82
src/diffusers/pipelines/pipeline_pndm.py
src/diffusers/pipelines/pipeline_pndm.py
+13
-49
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+53
-33
No files found.
src/diffusers/pipelines/pipeline_pndm.py
View file @
559b8cbf
...
@@ -32,9 +32,6 @@ class PNDM(DiffusionPipeline):
...
@@ -32,9 +32,6 @@ class PNDM(DiffusionPipeline):
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
num_trained_timesteps
=
self
.
noise_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
# Sample gaussian noise to begin loop
# Sample gaussian noise to begin loop
...
@@ -44,55 +41,22 @@ class PNDM(DiffusionPipeline):
...
@@ -44,55 +41,22 @@ class PNDM(DiffusionPipeline):
)
)
image
=
image
.
to
(
torch_device
)
image
=
image
.
to
(
torch_device
)
seq
=
list
(
inference_step_times
)
warmup_time_steps
=
self
.
noise_scheduler
.
get_warmup_time_steps
(
num_inference_steps
)
seq_next
=
[
-
1
]
+
list
(
seq
[:
-
1
])
prev_image
=
image
model
=
self
.
unet
for
t
in
tqdm
.
tqdm
(
range
(
len
(
warmup_time_steps
))):
t_orig
=
warmup_time_steps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
warmup_time_steps
=
list
(
reversed
([(
t
+
5
)
//
10
*
10
for
t
in
range
(
seq
[
-
4
],
seq
[
-
1
],
5
)]))
if
t
%
4
==
0
:
prev_image
=
image
cur_residual
=
0
image
=
self
.
noise_scheduler
.
step_warm_up
(
residual
,
prev_image
,
t
,
num_inference_steps
)
prev_image
=
image
ets
=
[]
for
i
in
range
(
len
(
warmup_time_steps
)):
t
=
warmup_time_steps
[
i
]
*
torch
.
ones
(
image
.
shape
[
0
])
t_next
=
(
warmup_time_steps
[
i
+
1
]
if
i
<
len
(
warmup_time_steps
)
-
1
else
warmup_time_steps
[
-
1
])
*
torch
.
ones
(
image
.
shape
[
0
])
residual
=
model
(
image
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
timesteps
=
self
.
noise_scheduler
.
get_time_steps
(
num_inference_steps
)
residual
=
residual
.
to
(
"cpu"
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
timesteps
))):
t_orig
=
timesteps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
if
i
%
4
==
0
:
image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
)
cur_residual
+=
1
/
6
*
residual
ets
.
append
(
residual
)
prev_image
=
image
elif
(
i
-
1
)
%
4
==
0
:
cur_residual
+=
1
/
3
*
residual
elif
(
i
-
2
)
%
4
==
0
:
cur_residual
+=
1
/
3
*
residual
elif
(
i
-
3
)
%
4
==
0
:
cur_residual
+=
1
/
6
*
residual
residual
=
cur_residual
cur_residual
=
0
image
=
image
.
to
(
"cpu"
)
t_2
=
warmup_time_steps
[
4
*
(
i
//
4
)]
*
torch
.
ones
(
image
.
shape
[
0
])
image
=
self
.
noise_scheduler
.
transfer
(
prev_image
.
to
(
"cpu"
),
t_2
,
t_next
,
residual
)
step_idx
=
len
(
seq
)
-
4
while
step_idx
>=
0
:
i
=
seq
[
step_idx
]
j
=
seq_next
[
step_idx
]
t
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
i
)
t_next
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
j
)
residual
=
model
(
image
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
residual
=
residual
.
to
(
"cpu"
)
ets
.
append
(
residual
)
residual
=
(
1
/
24
)
*
(
55
*
ets
[
-
1
]
-
59
*
ets
[
-
2
]
+
37
*
ets
[
-
3
]
-
9
*
ets
[
-
4
])
img_next
=
self
.
noise_scheduler
.
transfer
(
image
.
to
(
"cpu"
),
t
,
t_next
,
residual
)
image
=
img_next
step_idx
=
step_idx
-
1
return
image
return
image
src/diffusers/schedulers/scheduling_pndm.py
View file @
559b8cbf
...
@@ -55,6 +55,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -55,6 +55,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
set_format
(
tensor_format
=
tensor_format
)
# hardcode for now
self
.
pndm_order
=
4
self
.
cur_residual
=
0
# running values
self
.
ets
=
[]
self
.
warmup_time_steps
=
{}
self
.
time_steps
=
{}
# self.register_buffer("betas", betas.to(torch.float32))
# self.register_buffer("betas", betas.to(torch.float32))
# self.register_buffer("alphas", alphas.to(torch.float32))
# self.register_buffer("alphas", alphas.to(torch.float32))
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
...
@@ -83,51 +92,62 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -83,51 +92,62 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
self
.
one
return
self
.
one
return
self
.
alphas_cumprod
[
time_step
]
return
self
.
alphas_cumprod
[
time_step
]
def
step
(
self
,
img
,
t_start
,
t_end
,
model
,
ets
):
def
get_warmup_time_steps
(
self
,
num_inference_steps
):
# img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
if
num_inference_steps
in
self
.
warmup_time_steps
:
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
return
self
.
warmup_time_steps
[
num_inference_steps
]
t_next
,
t
=
t_start
,
t_end
noise_
=
model
(
img
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
inference_step_times
=
list
(
range
(
0
,
self
.
timesteps
,
self
.
timesteps
//
num_inference_steps
))
noise_
=
noise_
.
to
(
"cpu"
)
t_list
=
[
t
,
(
t
+
t_next
)
/
2
,
t_next
]
warmup_time_steps
=
np
.
array
(
inference_step_times
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
if
len
(
ets
)
>
2
:
self
.
warmup_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
warmup_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
ets
.
append
(
noise_
)
noise
=
(
1
/
24
)
*
(
55
*
ets
[
-
1
]
-
59
*
ets
[
-
2
]
+
37
*
ets
[
-
3
]
-
9
*
ets
[
-
4
])
return
self
.
warmup_time_steps
[
num_inference_steps
]
else
:
noise
=
self
.
runge_kutta
(
img
,
t_list
,
model
,
ets
,
noise_
)
img_next
=
self
.
transfer
(
img
.
to
(
"cpu"
),
t
,
t_next
,
noise
)
def
get_time_steps
(
self
,
num_inference_steps
):
return
img_next
,
ets
if
num_inference_steps
in
self
.
time_steps
:
return
self
.
time_steps
[
num_inference_steps
]
def
runge_kutta
(
self
,
x
,
t_list
,
model
,
ets
,
noise_
):
inference_step_times
=
list
(
range
(
0
,
self
.
timesteps
,
self
.
timesteps
//
num_inference_steps
))
model
=
model
.
to
(
"cuda"
)
self
.
time_steps
[
num_inference_steps
]
=
list
(
reversed
(
inference_step_times
[:
-
3
]))
x
=
x
.
to
(
"cpu"
)
e_1
=
noise_
return
self
.
time_steps
[
num_inference_steps
]
ets
.
append
(
e_1
)
x_2
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_1
)
e_2
=
model
(
x_2
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
def
step_warm_up
(
self
,
residual
,
image
,
t
,
num_inference_steps
):
e_2
=
e_2
.
to
(
"cpu"
)
warmup_time_steps
=
self
.
get_warmup_time_steps
(
num_inference_steps
)
x_3
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_2
)
e_3
=
model
(
x_3
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
t_prev
=
warmup_time_steps
[
t
//
4
*
4
]
e_3
=
e_3
.
to
(
"cpu"
)
t_next
=
warmup_time_steps
[
min
(
t
+
1
,
len
(
warmup_time_steps
)
-
1
)]
x_4
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
2
],
e_3
)
e_4
=
model
(
x_4
.
to
(
"cuda"
),
t_list
[
2
].
to
(
"cuda"
))
if
t
%
4
==
0
:
e_4
=
e_4
.
to
(
"cpu"
)
self
.
cur_residual
+=
1
/
6
*
residual
self
.
ets
.
append
(
residual
)
elif
(
t
-
1
)
%
4
==
0
:
self
.
cur_residual
+=
1
/
3
*
residual
elif
(
t
-
2
)
%
4
==
0
:
self
.
cur_residual
+=
1
/
3
*
residual
elif
(
t
-
3
)
%
4
==
0
:
residual
=
self
.
cur_residual
+
1
/
6
*
residual
self
.
cur_residual
=
0
et
=
(
1
/
6
)
*
(
e_1
+
2
*
e_2
+
2
*
e_3
+
e_4
)
r
et
urn
self
.
transfer
(
image
,
t_prev
,
t_next
,
residual
)
return
et
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
):
timesteps
=
self
.
get_time_steps
(
num_inference_steps
)
t_prev
=
timesteps
[
t
]
t_next
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
self
.
ets
.
append
(
residual
)
residual
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
return
self
.
transfer
(
image
,
t_prev
,
t_next
,
residual
)
def
transfer
(
self
,
x
,
t
,
t_next
,
et
):
def
transfer
(
self
,
x
,
t
,
t_next
,
et
):
alphas_cump
=
self
.
alphas_cumprod
# TODO(Patrick): clean up to be compatible with numpy and give better names
at
=
alphas_cump
[
t
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
alphas_cump
=
self
.
alphas_cumprod
.
to
(
x
.
device
)
at
=
alphas_cump
[
t
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
+
1
].
view
(
-
1
,
1
,
1
,
1
)
x_delta
=
(
at_next
-
at
)
*
((
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
)
x_delta
=
(
at_next
-
at
)
*
((
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment