Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
3a32b8c9
Commit
3a32b8c9
authored
Jul 19, 2022
by
Patrick von Platen
Browse files
align API
parent
c3a15437
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
26 deletions
+15
-26
src/diffusers/pipelines/ddim/pipeline_ddim.py
src/diffusers/pipelines/ddim/pipeline_ddim.py
+2
-5
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+3
-6
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+4
-9
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+6
-6
No files found.
src/diffusers/pipelines/ddim/pipeline_ddim.py
View file @
3a32b8c9
...
@@ -27,6 +27,7 @@ class DDIMPipeline(DiffusionPipeline):
...
@@ -27,6 +27,7 @@ class DDIMPipeline(DiffusionPipeline):
scheduler
=
scheduler
.
set_format
(
"pt"
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
# eta corresponds to η in paper and should be between [0, 1]
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
if
torch_device
is
None
:
...
@@ -46,11 +47,7 @@ class DDIMPipeline(DiffusionPipeline):
...
@@ -46,11 +47,7 @@ class DDIMPipeline(DiffusionPipeline):
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
# 1. predict noise model_output
# 1. predict noise model_output
with
torch
.
no_grad
():
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
model_output
=
self
.
unet
(
image
,
t
)
if
isinstance
(
model_output
,
dict
):
model_output
=
model_output
[
"sample"
]
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
# do x_t -> x_t-1
...
...
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
3a32b8c9
...
@@ -27,6 +27,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -27,6 +27,7 @@ class DDPMPipeline(DiffusionPipeline):
scheduler
=
scheduler
.
set_format
(
"pt"
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
):
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -45,11 +46,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -45,11 +46,7 @@ class DDPMPipeline(DiffusionPipeline):
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
# 1. predict noise model_output
# 1. predict noise model_output
with
torch
.
no_grad
():
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
model_output
=
self
.
unet
(
image
,
t
)
if
isinstance
(
model_output
,
dict
):
model_output
=
model_output
[
"sample"
]
# 2. predict previous mean of image x_t-1
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
pred_prev_image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
...
@@ -63,4 +60,4 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -63,4 +60,4 @@ class DDPMPipeline(DiffusionPipeline):
# 4. set current image to prev_image: x_t -> x_t-1
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
image
=
pred_prev_image
+
variance
return
image
return
{
"sample"
:
image
}
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
3a32b8c9
...
@@ -27,6 +27,7 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -27,6 +27,7 @@ class PNDMPipeline(DiffusionPipeline):
scheduler
=
scheduler
.
set_format
(
"pt"
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
):
# For more information on the sampling method you can take a look at Algorithm 2 of
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
...
@@ -45,21 +46,15 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -45,21 +46,15 @@ class PNDMPipeline(DiffusionPipeline):
prk_time_steps
=
self
.
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
prk_time_steps
=
self
.
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
for
t
in
tqdm
(
range
(
len
(
prk_time_steps
))):
for
t
in
tqdm
(
range
(
len
(
prk_time_steps
))):
t_orig
=
prk_time_steps
[
t
]
t_orig
=
prk_time_steps
[
t
]
model_output
=
self
.
unet
(
image
,
t_orig
)
model_output
=
self
.
unet
(
image
,
t_orig
)[
"sample"
]
if
isinstance
(
model_output
,
dict
):
model_output
=
model_output
[
"sample"
]
image
=
self
.
scheduler
.
step_prk
(
model_output
,
t
,
image
,
num_inference_steps
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step_prk
(
model_output
,
t
,
image
,
num_inference_steps
)[
"prev_sample"
]
timesteps
=
self
.
scheduler
.
get_time_steps
(
num_inference_steps
)
timesteps
=
self
.
scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
tqdm
(
range
(
len
(
timesteps
))):
for
t
in
tqdm
(
range
(
len
(
timesteps
))):
t_orig
=
timesteps
[
t
]
t_orig
=
timesteps
[
t
]
model_output
=
self
.
unet
(
image
,
t_orig
)
model_output
=
self
.
unet
(
image
,
t_orig
)[
"sample"
]
if
isinstance
(
model_output
,
dict
):
model_output
=
model_output
[
"sample"
]
image
=
self
.
scheduler
.
step_plms
(
model_output
,
t
,
image
,
num_inference_steps
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step_plms
(
model_output
,
t
,
image
,
num_inference_steps
)[
"prev_sample"
]
return
image
return
{
"sample"
:
image
}
tests/test_modeling_utils.py
View file @
3a32b8c9
...
@@ -665,9 +665,9 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -665,9 +665,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
image
=
ddpm
(
generator
=
generator
)
[
"sample"
]
generator
=
generator
.
manual_seed
(
0
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
new_ddpm
(
generator
=
generator
)
new_image
=
new_ddpm
(
generator
=
generator
)
[
"sample"
]
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
...
@@ -683,9 +683,9 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -683,9 +683,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
image
=
ddpm
(
generator
=
generator
)
[
"sample"
]
generator
=
generator
.
manual_seed
(
0
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
ddpm_from_hub
(
generator
=
generator
)
new_image
=
ddpm_from_hub
(
generator
=
generator
)
[
"sample"
]
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
...
@@ -700,7 +700,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -700,7 +700,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
image
=
ddpm
(
generator
=
generator
)
[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
@@ -759,7 +759,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -759,7 +759,7 @@ class PipelineTesterMixin(unittest.TestCase):
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
pndm
(
generator
=
generator
)
image
=
pndm
(
generator
=
generator
)
[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment