Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1b42732c
Unverified
Commit
1b42732c
authored
Jul 20, 2022
by
Anton Lozhkov
Committed by
GitHub
Jul 20, 2022
Browse files
PIL-ify the pipeline outputs (#111)
parent
9e9d2dbc
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
49 additions
and
10 deletions
+49
-10
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+13
-0
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+3
-1
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
...s/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+3
-0
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+3
-6
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+3
-1
src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
...diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+3
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+21
-1
No files found.
src/diffusers/pipeline_utils.py
View file @
1b42732c
...
@@ -19,6 +19,7 @@ import os
...
@@ -19,6 +19,7 @@ import os
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
PIL
import
Image
from
.configuration_utils
import
ConfigMixin
from
.configuration_utils
import
ConfigMixin
from
.utils
import
DIFFUSERS_CACHE
,
logging
from
.utils
import
DIFFUSERS_CACHE
,
logging
...
@@ -189,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -189,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
# 5. Instantiate the pipeline
# 5. Instantiate the pipeline
model
=
pipeline_class
(
**
init_kwargs
)
model
=
pipeline_class
(
**
init_kwargs
)
return
model
return
model
@
staticmethod
def
numpy_to_pil
(
images
):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if
images
.
ndim
==
3
:
images
=
images
[
None
,
...]
images
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
pil_images
=
[
Image
.
fromarray
(
image
)
for
image
in
images
]
return
pil_images
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
1b42732c
...
@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
output_type
=
"numpy"
):
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -56,5 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -56,5 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
View file @
1b42732c
...
@@ -30,6 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
...
@@ -30,6 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
eta
=
0.0
,
eta
=
0.0
,
guidance_scale
=
1.0
,
guidance_scale
=
1.0
,
num_inference_steps
=
50
,
num_inference_steps
=
50
,
output_type
=
"numpy"
,
):
):
# eta corresponds to η in paper and should be between [0, 1]
# eta corresponds to η in paper and should be between [0, 1]
...
@@ -86,6 +87,8 @@ class LatentDiffusionPipeline(DiffusionPipeline):
...
@@ -86,6 +87,8 @@ class LatentDiffusionPipeline(DiffusionPipeline):
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
1b42732c
...
@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
...
@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
def
__call__
(
self
,
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
,
output_type
=
"numpy"
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
,
):
):
# eta corresponds to η in paper and should be between [0, 1]
# eta corresponds to η in paper and should be between [0, 1]
...
@@ -47,5 +42,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
...
@@ -47,5 +42,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
1b42732c
...
@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
,
output_type
=
"numpy"
):
# For more information on the sampling method you can take a look at Algorithm 2 of
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
if
torch_device
is
None
:
if
torch_device
is
None
:
...
@@ -59,5 +59,7 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -59,5 +59,7 @@ class PNDMPipeline(DiffusionPipeline):
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
View file @
1b42732c
...
@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
...
@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
self
,
num_inference_steps
=
2000
,
generator
=
None
):
def
__call__
(
self
,
num_inference_steps
=
2000
,
generator
=
None
,
output_type
=
"numpy"
):
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
img_size
=
self
.
model
.
config
.
image_size
img_size
=
self
.
model
.
config
.
image_size
...
@@ -47,5 +47,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
...
@@ -47,5 +47,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
sample
=
sample
.
clamp
(
0
,
1
)
sample
=
sample
.
clamp
(
0
,
1
)
sample
=
sample
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
sample
=
sample
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
sample
=
self
.
numpy_to_pil
(
sample
)
return
{
"sample"
:
sample
}
return
{
"sample"
:
sample
}
tests/test_modeling_utils.py
View file @
1b42732c
...
@@ -18,11 +18,11 @@ import inspect
...
@@ -18,11 +18,11 @@ import inspect
import
math
import
math
import
tempfile
import
tempfile
import
unittest
import
unittest
from
atexit
import
register
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
PIL
from
diffusers
import
UNetConditionalModel
# noqa: F401 TODO(Patrick) - need to write tests with it
from
diffusers
import
UNetConditionalModel
# noqa: F401 TODO(Patrick) - need to write tests with it
from
diffusers
import
(
from
diffusers
import
(
AutoencoderKL
,
AutoencoderKL
,
...
@@ -728,6 +728,26 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -728,6 +728,26 @@ class PipelineTesterMixin(unittest.TestCase):
assert
np
.
abs
(
image
-
new_image
).
sum
()
<
1e-5
,
"Models don't give the same forward pass"
assert
np
.
abs
(
image
-
new_image
).
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
slow
def
test_output_format
(
self
):
model_path
=
"google/ddpm-cifar10-32"
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
)
generator
=
torch
.
manual_seed
(
0
)
images
=
pipe
(
generator
=
generator
)[
"sample"
]
assert
images
.
shape
==
(
1
,
32
,
32
,
3
)
assert
isinstance
(
images
,
np
.
ndarray
)
images
=
pipe
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
assert
images
.
shape
==
(
1
,
32
,
32
,
3
)
assert
isinstance
(
images
,
np
.
ndarray
)
images
=
pipe
(
generator
=
generator
,
output_type
=
"pil"
)[
"sample"
]
assert
isinstance
(
images
,
list
)
assert
len
(
images
)
==
1
assert
isinstance
(
images
[
0
],
PIL
.
Image
.
Image
)
@
slow
@
slow
def
test_ddpm_cifar10
(
self
):
def
test_ddpm_cifar10
(
self
):
model_id
=
"google/ddpm-cifar10-32"
model_id
=
"google/ddpm-cifar10-32"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment