renzhc / diffusers_dcu, commit 7a1323b6
Authored Jun 08, 2022 by Patrick von Platen
Parent: 86064df7

add first version of ddim
Showing 2 changed files with 38 additions and 4 deletions:

    models/vision/ddim/modeling_ddim.py    +8   -4
    tests/test_modeling_utils.py           +30  -0
models/vision/ddim/modeling_ddim.py
@@ -32,11 +32,16 @@ class DDIM(DiffusionPipeline):
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

     def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50):
-        seq = range(0, self.num_timesteps, self.num_timesteps // inference_time_steps)
-        b = self.noise_scheduler.betas
+        # eta is η in paper

         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"

+        num_timesteps = self.noise_scheduler.num_timesteps
+        seq = range(0, num_timesteps, num_timesteps // inference_time_steps)
+        b = self.noise_scheduler.betas.to(torch_device)
+
         self.unet.to(torch_device)

         x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
@@ -63,5 +68,4 @@ class DDIM(DiffusionPipeline):
             xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
             xs.append(xt_next.to('cpu'))

-        import ipdb; ipdb.set_trace()
-        return xs, x0_preds
+        return xt_next
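For context on the update line kept in the second hunk (xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et), the sketch below shows one DDIM reverse step. The definitions of x0_t, c1, and c2 are not part of this diff; they follow the DDIM paper (Song et al., 2021) and are stated here as assumptions about the surrounding code, not a copy of it.

import torch


def ddim_step(x, et, at, at_next, eta=0.0):
    # One DDIM reverse step. x and et are the current sample x_t and the model's
    # predicted noise; at and at_next are the cumulative alpha products (alpha-bar)
    # for the current and next timestep. Only the final return line appears in the
    # diff above; x0_t, c1 and c2 below are assumed from the DDIM paper.
    x0_t = (x - et * (1 - at).sqrt()) / at.sqrt()  # predicted x_0
    c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()  # sigma_t; 0 when eta == 0
    c2 = (1 - at_next - c1 ** 2).sqrt()  # weight on the predicted noise
    # Same update as in the hunk: deterministic when eta == 0, DDPM-like when eta == 1
    return at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et


# Toy usage with scalar alpha-bar values; the shapes broadcast over the image batch.
x = torch.randn(1, 3, 32, 32)
et = torch.randn_like(x)
xt_next = ddim_step(x, et, at=torch.tensor(0.5), at_next=torch.tensor(0.7), eta=0.0)

Setting eta to 0.0, as the new test below does, removes the random term entirely, which is what makes DDIM sampling reproducible from a seeded generator.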
tests/test_modeling_utils.py
@@ -25,6 +25,7 @@ import torch
 from diffusers import GaussianDDPMScheduler, UNetModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from models.vision.ddpm.modeling_ddpm import DDPM
+from models.vision.ddim.modeling_ddim import DDIM

 global_rng = random.Random()
@@ -205,6 +206,7 @@ class SamplerTesterMixin(unittest.TestCase):
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
         model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
@@ -241,3 +243,31 @@ class PipelineTesterMixin(unittest.TestCase):
         new_image = ddpm_from_hub(generator=generator)

         assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+
+    @slow
+    def test_ddpm_cifar10(self):
+        generator = torch.manual_seed(0)
+        model_id = "fusing/ddpm-cifar10"
+
+        ddpm = DDPM.from_pretrained(model_id)
+        image = ddpm(generator=generator)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 32, 32)
+        expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+
+    @slow
+    def test_ddim_cifar10(self):
+        generator = torch.manual_seed(0)
+        model_id = "fusing/ddpm-cifar10"
+
+        ddim = DDIM.from_pretrained(model_id)
+        image = ddim(generator=generator, eta=0.0)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 32, 32)
+        expected_slice = torch.tensor([-0.7688, -0.7690, -0.7597, -0.7660, -0.7713, -0.7531, -0.7009, -0.7098, -0.7350])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
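The new test_ddim_cifar10 mirrors test_ddpm_cifar10: it loads the fusing/ddpm-cifar10 weights through the DDIM pipeline, samples one 32x32 image with a fixed seed and eta=0.0, and checks a 3x3 corner of the result against hard-coded values. A standalone sketch of the same flow is below; it assumes this commit's repository layout is importable (models/vision/... on the path) and that the checkpoint is reachable, so treat it as illustrative rather than a verified script.

import torch

# Module path taken directly from the test's new import line; requires the
# repository root of this commit to be on PYTHONPATH.
from models.vision.ddim.modeling_ddim import DDIM

generator = torch.manual_seed(0)

# Same checkpoint the test uses; DDIM reuses the DDPM-trained CIFAR-10 weights.
ddim = DDIM.from_pretrained("fusing/ddpm-cifar10")

# eta=0.0 makes the sampler deterministic given the seeded generator.
image = ddim(generator=generator, eta=0.0)

assert image.shape == (1, 3, 32, 32)
print(image[0, -1, -3:, -3:].cpu())  # the slice the test compares to expected_slice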