renzhc / diffusers_dcu

Commit 3a32b8c9
authored Jul 19, 2022 by Patrick von Platen
parent c3a15437

align API

Showing 4 changed files with 15 additions and 26 deletions.
src/diffusers/pipelines/ddim/pipeline_ddim.py   +2 -5
src/diffusers/pipelines/ddpm/pipeline_ddpm.py   +3 -6
src/diffusers/pipelines/pndm/pipeline_pndm.py   +4 -9
tests/test_modeling_utils.py                    +6 -6
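In practice, "align API" makes the three pipelines follow one calling convention: __call__ is decorated with @torch.no_grad(), the UNet output inside the sampling loop is read directly through its "sample" key instead of an isinstance check, and the pipeline's return value becomes {"sample": image} rather than the bare image tensor (shown here for DDPM and PNDM), which is why the tests below now index the call result with ["sample"]. A minimal before/after sketch of the caller-side change (ddpm stands for any already-constructed pipeline, as in tests/test_modeling_utils.py):

    generator = torch.manual_seed(0)

    # before this commit: the pipeline returned the image tensor directly
    image = ddpm(generator=generator)

    # after this commit: the pipeline returns a dict, so the image is read via ["sample"]
    image = ddpm(generator=generator)["sample"]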
src/diffusers/pipelines/ddim/pipeline_ddim.py

@@ -27,6 +27,7 @@ class DDIMPipeline(DiffusionPipeline):
         scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
+    @torch.no_grad()
     def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:
@@ -46,11 +47,7 @@ class DDIMPipeline(DiffusionPipeline):
         for t in tqdm(self.scheduler.timesteps):
             # 1. predict noise model_output
-            with torch.no_grad():
-                model_output = self.unet(image, t)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+            model_output = self.unet(image, t)["sample"]
 
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # do x_t -> x_t-1
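The @torch.no_grad() decorator added above replaces the with torch.no_grad(): block that used to wrap the UNet call inside the loop; for inference the two are equivalent, since the decorator simply turns off gradient tracking for the entire __call__. The same substitution appears in the DDPM and PNDM pipelines below. A small self-contained illustration (the linear layer is a hypothetical stand-in for the UNet, not code from this repo):

    import torch

    linear = torch.nn.Linear(4, 4)  # stand-in for self.unet

    def sample_with_context(x):
        # old style: wrap the model call in a context manager
        with torch.no_grad():
            return linear(x)

    @torch.no_grad()
    def sample_with_decorator(x):
        # new style: the decorator disables gradient tracking for the whole call
        return linear(x)

    a = sample_with_context(torch.randn(1, 4))
    b = sample_with_decorator(torch.randn(1, 4))
    assert not a.requires_grad and not b.requires_grad  # no graph built in either case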
src/diffusers/pipelines/ddpm/pipeline_ddpm.py

@@ -27,6 +27,7 @@ class DDPMPipeline(DiffusionPipeline):
         scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
+    @torch.no_grad()
     def __call__(self, batch_size=1, generator=None, torch_device=None):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -45,11 +46,7 @@ class DDPMPipeline(DiffusionPipeline):
         for t in tqdm(self.scheduler.timesteps):
             # 1. predict noise model_output
-            with torch.no_grad():
-                model_output = self.unet(image, t)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+            model_output = self.unet(image, t)["sample"]
 
             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
@@ -63,4 +60,4 @@ class DDPMPipeline(DiffusionPipeline):
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance
 
-        return image
+        return {"sample": image}
src/diffusers/pipelines/pndm/pipeline_pndm.py

@@ -27,6 +27,7 @@ class PNDMPipeline(DiffusionPipeline):
         scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
+    @torch.no_grad()
     def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
@@ -45,21 +46,15 @@ class PNDMPipeline(DiffusionPipeline):
         prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
         for t in tqdm(range(len(prk_time_steps))):
             t_orig = prk_time_steps[t]
-            model_output = self.unet(image, t_orig)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+            model_output = self.unet(image, t_orig)["sample"]
 
             image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
 
         timesteps = self.scheduler.get_time_steps(num_inference_steps)
         for t in tqdm(range(len(timesteps))):
             t_orig = timesteps[t]
-            model_output = self.unet(image, t_orig)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+            model_output = self.unet(image, t_orig)["sample"]
 
             image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
 
-        return image
+        return {"sample": image}
tests/test_modeling_utils.py

@@ -665,9 +665,9 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)
+        image = ddpm(generator=generator)["sample"]
 
         generator = generator.manual_seed(0)
-        new_image = new_ddpm(generator=generator)
+        new_image = new_ddpm(generator=generator)["sample"]
 
         assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
@@ -683,9 +683,9 @@ class PipelineTesterMixin(unittest.TestCase):
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)
+        image = ddpm(generator=generator)["sample"]
 
         generator = generator.manual_seed(0)
-        new_image = ddpm_from_hub(generator=generator)
+        new_image = ddpm_from_hub(generator=generator)["sample"]
 
         assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
@@ -700,7 +700,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)
+        image = ddpm(generator=generator)["sample"]
 
         image_slice = image[0, -1, -3:, -3:].cpu()
@@ -759,7 +759,7 @@ class PipelineTesterMixin(unittest.TestCase):
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = pndm(generator=generator)
+        image = pndm(generator=generator)["sample"]
 
         image_slice = image[0, -1, -3:, -3:].cpu()