OpenDAS / diffusers / Commits / 24895a1f

Commit 24895a1f (unverified), authored Nov 09, 2022 by Anton Lozhkov, committed via GitHub on Nov 09, 2022. Parent: 598ff76b

Fix cpu offloading (#1177)

* Fix cpu offloading
* get offloaded devices locally for SD pipelines

Showing 7 changed files with 107 additions and 61 deletions.
src/diffusers/pipeline_utils.py                                                 +0  -2
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py          +28 -12
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py  +28  -9
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py  +32 -15
tests/pipelines/stable_diffusion/test_stable_diffusion.py                       +6  -4
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py               +4 -11
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py               +9  -8
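For orientation, the feature this commit repairs is exercised roughly like this (a minimal usage sketch assembled from the integration tests below; the model id, prompt, and step count are taken from those tests):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
# Move weights off the GPU; submodules are streamed onto it one at a time.
pipe.enable_sequential_cpu_offload()

# After offloading, pipe.device reports torch.device("meta"); with this fix the
# pipeline resolves the real execution device internally via _execution_device.
image = pipe("Andromeda galaxy in a bottle", num_inference_steps=5).images[0]
```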
src/diffusers/pipeline_utils.py

```diff
@@ -230,8 +230,6 @@ class DiffusionPipeline(ConfigMixin):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
-                if module.device == torch.device("meta"):
-                    return torch.device("cpu")
                 return module.device

         return torch.device("cpu")
```
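The removed branch used to translate the meta device into `"cpu"`, which hid the fact that a pipeline was offloaded and caused tensors to be created on the wrong device. A small illustrative sketch of the meta device itself (plain PyTorch, independent of diffusers):

```python
import torch

# Parameters on the meta device have shapes but no storage; this is how
# offloaded submodules look before Accelerate streams them onto the GPU.
lin = torch.nn.Linear(4, 4, device="meta")
print(next(lin.parameters()).device)  # -> device(type='meta')
```

After this change, `DiffusionPipeline.device` reports `meta` honestly; callers that need a device to allocate tensors on use the new `_execution_device` property added below.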
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

```diff
@@ -195,6 +195,24 @@ class StableDiffusionPipeline(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     @torch.no_grad()
     def __call__(
         self,
@@ -286,6 +304,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )

+        device = self._execution_device
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -302,7 +322,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = text_embeddings.shape
@@ -342,7 +362,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = uncond_embeddings.shape[1]
@@ -362,20 +382,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
         latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            if self.device.type == "mps":
+            if device.type == "mps":
                 # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            latents = latents.to(device)

         # set timesteps and move to the correct device
-        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps_tensor = self.scheduler.timesteps

         # scale the initial noise by the standard deviation required by the scheduler
@@ -424,9 +442,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()

         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
             )
```
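The new `_execution_device` property relies on the `_hf_hook` attribute that Accelerate's `cpu_offload` attaches to each submodule. A minimal sketch of that mechanism (assuming `accelerate` is installed and a CUDA device is available; the toy model is illustrative):

```python
import torch
from accelerate import cpu_offload

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
# cpu_offload attaches an AlignDevicesHook (exposed as module._hf_hook)
# that records the device the module should execute on.
cpu_offload(model, torch.device("cuda:0"))

# This mirrors the lookup performed by _execution_device above.
for module in model.modules():
    hook = getattr(module, "_hf_hook", None)
    if hook is not None and getattr(hook, "execution_device", None) is not None:
        print(torch.device(hook.execution_device))  # -> cuda:0
        break
```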
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

```diff
@@ -183,6 +183,25 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -292,6 +311,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )

+        device = self._execution_device
+
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -314,7 +335,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt
         text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
@@ -348,7 +369,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

             # duplicate unconditional embeddings for each generation per prompt
             seq_len = uncond_embeddings.shape[1]
@@ -362,7 +383,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # encode the init image into latents and scale the latents
         latents_dtype = text_embeddings.dtype
-        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_image = init_image.to(device=device, dtype=latents_dtype)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
@@ -393,10 +414,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         init_timestep = min(init_timestep, num_inference_steps)

         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=device)

         # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=latents_dtype)
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -419,7 +440,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # Some schedulers like PNDM have timesteps as arrays
         # It's more optimized to move all timesteps to correct device beforehand
-        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
+        timesteps = self.scheduler.timesteps[t_start:].to(device)

         for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
@@ -448,9 +469,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         image = image.cpu().permute(0, 2, 3, 1).numpy()

         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
             )
```
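As in the text-to-image pipeline, every `self.device` reference inside `__call__` is replaced by one locally resolved device, so all intermediate tensors land on the GPU even while the module weights live on `meta`. A runnable toy sketch of this "resolve the device once, pass it everywhere" pattern (shapes and step count are illustrative):

```python
import torch

# Resolve a single local device up front, mirroring `device = self._execution_device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

generator = torch.Generator(device=device).manual_seed(0)
latents = torch.randn((1, 4, 64, 64), generator=generator, device=device)
timesteps = torch.arange(50, device=device)
print(latents.device, timesteps.device)  # both on the resolved device
```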
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

```diff
@@ -183,6 +183,25 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)

+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     def enable_xformers_memory_efficient_attention(self):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -303,6 +322,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 f" {type(callback_steps)}."
             )

+        device = self._execution_device
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -319,7 +340,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = text_embeddings.shape
@@ -359,7 +380,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = uncond_embeddings.shape[1]
@@ -379,17 +400,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            if self.device.type == "mps":
+            if device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            latents = latents.to(device)

         # prepare mask and masked_image
         mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
@@ -398,9 +417,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
         mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
-        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+        mask = mask.to(device=device, dtype=text_embeddings.dtype)

-        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
+        masked_image = masked_image.to(device=device, dtype=text_embeddings.dtype)

         # encode the mask image into latents space so we can concatenate it to the latents
         masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
@@ -416,7 +435,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         )

         # aligning device to prevent device errors when concating it with the latent model input
-        masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype)
+        masked_image_latents = masked_image_latents.to(device=device, dtype=text_embeddings.dtype)

         num_channels_mask = mask.shape[1]
         num_channels_masked_image = masked_image_latents.shape[1]
@@ -431,7 +450,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         )

         # set timesteps and move to the correct device
-        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps_tensor = self.scheduler.timesteps

         # scale the initial noise by the standard deviation required by the scheduler
@@ -484,9 +503,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()

         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
-                self.device
-            )
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
             )
```
tests/pipelines/stable_diffusion/test_stable_diffusion.py

```diff
@@ -839,20 +839,22 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert 2 * low_cpu_mem_usage_time < normal_load_time

+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()

         pipeline_id = "CompVis/stable-diffusion-v1-4"
         prompt = "Andromeda galaxy in a bottle"

         pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
-        pipeline = pipeline.to(torch_device)
         pipeline.enable_attention_slicing(1)
         pipeline.enable_sequential_cpu_offload()

-        _ = pipeline(prompt, num_inference_steps=5)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipeline(prompt, generator=generator, num_inference_steps=5)

         mem_bytes = torch.cuda.max_memory_allocated()
-        # make sure that less than 1.5 GB is allocated
-        assert mem_bytes < 1.5 * 10**9
+        # make sure that less than 2.8 GB is allocated
+        assert mem_bytes < 2.8 * 10**9
```
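The memory-assertion pattern these tests rely on, as a standalone sketch (requires a CUDA device; the allocation here is a stand-in for running the pipeline, and the 2.8 GB bound is the one from the updated test):

```python
import torch

torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()

x = torch.zeros(1024, 1024, device="cuda")  # stand-in for the pipeline run

mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 2.8 * 10**9  # loosened bound for sequential CPU offloading
print(f"peak allocation: {mem_bytes / 2**20:.1f} MiB")
```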
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

```diff
@@ -603,25 +603,18 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()

         init_image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/img2img/sketch-mountains-input.jpg"
         )
-        expected_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/img2img/fantasy_landscape_k_lms.png"
-        )

         init_image = init_image.resize((768, 512))
-        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

         model_id = "CompVis/stable-diffusion-v1-4"
         lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
-            scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
+            scheduler=lms, safety_checker=None, device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -642,5 +635,5 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         )

         mem_bytes = torch.cuda.max_memory_allocated()
-        # make sure that less than 1.5 GB is allocated
-        assert mem_bytes < 1.5 * 10**9
+        # make sure that less than 2.2 GB is allocated
+        assert mem_bytes < 2.2 * 10**9
```
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

```diff
@@ -384,6 +384,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
     def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()

         init_image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
@@ -393,16 +394,16 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
         )
-        expected_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/in_paint/yellow_cat_sitting_on_a_park_bench_pndm.png"
-        )
-        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

         model_id = "runwayml/stable-diffusion-inpainting"
         pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
-            model_id, safety_checker=None, scheduler=pndm, device_map="auto"
+            model_id, safety_checker=None, scheduler=pndm, device_map="auto", revision="fp16", torch_dtype=torch.float16,
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -422,5 +423,5 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
         )

         mem_bytes = torch.cuda.max_memory_allocated()
-        # make sure that less than 1.5 GB is allocated
-        assert mem_bytes < 1.5 * 10**9
+        # make sure that less than 2.2 GB is allocated
+        assert mem_bytes < 2.2 * 10**9
```
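Note that the two integration tests move in opposite directions: the img2img test drops the `fp16` revision while the inpaint test adds it. Loading half-precision weights follows the pattern below (a sketch; it assumes the checkpoint publishes an `fp16` revision, as `runwayml/stable-diffusion-inpainting` does):

```python
import torch
from diffusers import StableDiffusionInpaintPipeline

# fp16 weights halve the download size and the peak GPU memory of the test
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
)
```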