Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
809591b7
Commit
809591b7
authored
Jun 13, 2022
by
Patrick von Platen
Browse files
improve pndm
parent
11631e81
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
17 deletions
+57
-17
src/diffusers/pipelines/pipeline_pndm.py
src/diffusers/pipelines/pipeline_pndm.py
+25
-4
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+12
-12
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+20
-1
No files found.
src/diffusers/pipelines/pipeline_pndm.py
View file @
809591b7
...
@@ -40,7 +40,7 @@ class PNDM(DiffusionPipeline):
...
@@ -40,7 +40,7 @@ class PNDM(DiffusionPipeline):
# Sample gaussian noise to begin loop
# Sample gaussian noise to begin loop
image
=
torch
.
randn
(
image
=
torch
.
randn
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
#
generator=
torch.manual_seed(0)
generator
=
generator
,
)
)
image
=
image
.
to
(
torch_device
)
image
=
image
.
to
(
torch_device
)
...
@@ -53,9 +53,30 @@ class PNDM(DiffusionPipeline):
...
@@ -53,9 +53,30 @@ class PNDM(DiffusionPipeline):
t
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
i
)
t
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
i
)
t_next
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
j
)
t_next
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
j
)
with
torch
.
no_grad
():
residual
=
model
(
image
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
t_start
,
t_end
=
t_next
,
t
residual
=
residual
.
to
(
"cpu"
)
img_next
,
ets
=
self
.
noise_scheduler
.
step
(
image
,
t_start
,
t_end
,
model
,
ets
)
t_list
=
[
t
,
(
t
+
t_next
)
/
2
,
t_next
]
if
len
(
ets
)
<=
2
:
ets
.
append
(
residual
)
image
=
image
.
to
(
"cpu"
)
x_2
=
self
.
noise_scheduler
.
transfer
(
image
,
t_list
[
0
],
t_list
[
1
],
residual
)
e_2
=
model
(
x_2
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
)).
to
(
"cpu"
)
x_3
=
self
.
noise_scheduler
.
transfer
(
image
,
t_list
[
0
],
t_list
[
1
],
e_2
)
e_3
=
model
(
x_3
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
)).
to
(
"cpu"
)
x_4
=
self
.
noise_scheduler
.
transfer
(
image
,
t_list
[
0
],
t_list
[
2
],
e_3
)
e_4
=
model
(
x_4
.
to
(
"cuda"
),
t_list
[
2
].
to
(
"cuda"
)).
to
(
"cpu"
)
residual
=
(
1
/
6
)
*
(
residual
+
2
*
e_2
+
2
*
e_3
+
e_4
)
else
:
ets
.
append
(
residual
)
residual
=
(
1
/
24
)
*
(
55
*
ets
[
-
1
]
-
59
*
ets
[
-
2
]
+
37
*
ets
[
-
3
]
-
9
*
ets
[
-
4
])
img_next
=
self
.
noise_scheduler
.
transfer
(
image
.
to
(
"cpu"
),
t
,
t_next
,
residual
)
# with torch.no_grad():
# t_start, t_end = t_next, t
# img_next, ets = self.noise_scheduler.step(image, t_start, t_end, model, ets)
image
=
img_next
image
=
img_next
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
809591b7
...
@@ -88,35 +88,34 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -88,35 +88,34 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
t_next
,
t
=
t_start
,
t_end
t_next
,
t
=
t_start
,
t_end
noise_
=
model
(
img
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
noise_
=
noise_
.
to
(
"cpu"
)
t_list
=
[
t
,
(
t
+
t_next
)
/
2
,
t_next
]
t_list
=
[
t
,
(
t
+
t_next
)
/
2
,
t_next
]
alphas_cump
=
self
.
alphas_cumprod
if
len
(
ets
)
>
2
:
if
len
(
ets
)
>
2
:
noise_
=
model
(
img
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
noise_
=
noise_
.
to
(
"cpu"
)
ets
.
append
(
noise_
)
ets
.
append
(
noise_
)
noise
=
(
1
/
24
)
*
(
55
*
ets
[
-
1
]
-
59
*
ets
[
-
2
]
+
37
*
ets
[
-
3
]
-
9
*
ets
[
-
4
])
noise
=
(
1
/
24
)
*
(
55
*
ets
[
-
1
]
-
59
*
ets
[
-
2
]
+
37
*
ets
[
-
3
]
-
9
*
ets
[
-
4
])
else
:
else
:
noise
=
self
.
runge_kutta
(
img
,
t_list
,
model
,
alphas_cump
,
ets
)
noise
=
self
.
runge_kutta
(
img
,
t_list
,
model
,
ets
,
noise_
)
img_next
=
self
.
transfer
(
img
.
to
(
"cpu"
),
t
,
t_next
,
noise
,
alphas_cump
)
img_next
=
self
.
transfer
(
img
.
to
(
"cpu"
),
t
,
t_next
,
noise
)
return
img_next
,
ets
return
img_next
,
ets
def
runge_kutta
(
self
,
x
,
t_list
,
model
,
alphas_cump
,
ets
):
def
runge_kutta
(
self
,
x
,
t_list
,
model
,
ets
,
noise_
):
model
=
model
.
to
(
"cuda"
)
model
=
model
.
to
(
"cuda"
)
x
=
x
.
to
(
"cpu"
)
x
=
x
.
to
(
"cpu"
)
e_1
=
model
(
x
.
to
(
"cuda"
),
t_list
[
0
].
to
(
"cuda"
))
e_1
=
noise_
e_1
=
e_1
.
to
(
"cpu"
)
ets
.
append
(
e_1
)
ets
.
append
(
e_1
)
x_2
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_1
,
alphas_cump
)
x_2
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_1
)
e_2
=
model
(
x_2
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
e_2
=
model
(
x_2
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
e_2
=
e_2
.
to
(
"cpu"
)
e_2
=
e_2
.
to
(
"cpu"
)
x_3
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_2
,
alphas_cump
)
x_3
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_2
)
e_3
=
model
(
x_3
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
e_3
=
model
(
x_3
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
e_3
=
e_3
.
to
(
"cpu"
)
e_3
=
e_3
.
to
(
"cpu"
)
x_4
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
2
],
e_3
,
alphas_cump
)
x_4
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
2
],
e_3
)
e_4
=
model
(
x_4
.
to
(
"cuda"
),
t_list
[
2
].
to
(
"cuda"
))
e_4
=
model
(
x_4
.
to
(
"cuda"
),
t_list
[
2
].
to
(
"cuda"
))
e_4
=
e_4
.
to
(
"cpu"
)
e_4
=
e_4
.
to
(
"cpu"
)
...
@@ -125,7 +124,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -125,7 +124,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
et
return
et
def
transfer
(
self
,
x
,
t
,
t_next
,
et
,
alphas_cump
):
def
transfer
(
self
,
x
,
t
,
t_next
,
et
):
alphas_cump
=
self
.
alphas_cumprod
at
=
alphas_cump
[
t
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at
=
alphas_cump
[
t
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
...
...
tests/test_modeling_utils.py
View file @
809591b7
...
@@ -19,7 +19,7 @@ import unittest
...
@@ -19,7 +19,7 @@ import unittest
import
torch
import
torch
from
diffusers
import
DDIM
,
DDPM
,
DDIMScheduler
,
DDPMScheduler
,
LatentDiffusion
,
UNetModel
from
diffusers
import
DDIM
,
DDPM
,
DDIMScheduler
,
DDPMScheduler
,
LatentDiffusion
,
UNetModel
,
PNDM
,
PNDMScheduler
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
...
@@ -178,6 +178,25 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -178,6 +178,25 @@ class PipelineTesterMixin(unittest.TestCase):
)
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_pndm_cifar10
(
self
):
generator
=
torch
.
manual_seed
(
0
)
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetModel
.
from_pretrained
(
model_id
)
noise_scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
pndm
=
PNDM
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
image
=
pndm
(
generator
=
generator
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
expected_slice
=
torch
.
tensor
(
[
-
0.7888
,
-
0.7870
,
-
0.7759
,
-
0.7823
,
-
0.8014
,
-
0.7608
,
-
0.6818
,
-
0.7130
,
-
0.7471
]
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
@
slow
def
test_ldm_text2img
(
self
):
def
test_ldm_text2img
(
self
):
model_id
=
"fusing/latent-diffusion-text2im-large"
model_id
=
"fusing/latent-diffusion-text2im-large"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment