Commit c7b4acfb (Unverified)
Authored Dec 19, 2022 by Anton Lozhkov, committed by GitHub on Dec 19, 2022
Parent: be38b2d7

Add CPU offloading to UnCLIP (#1761)

* Add CPU offloading to UnCLIP
* use fp32 for testing the offload

Changes: 2 changed files, with 85 additions and 15 deletions (+85, -15)

  src/diffusers/pipelines/unclip/pipeline_unclip.py   +57  -13
  tests/pipelines/unclip/test_unclip.py               +28  -2
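In practice the new API is used as below. This is a minimal usage sketch rather than code from the commit: the prompt string, output filename, and `gpu_id` value are illustrative, and it assumes a CUDA device plus an installed `accelerate` package.

from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")

# Instead of pipe.to("cuda"): register accelerate offload hooks so that each
# sub-model is moved to cuda:<gpu_id> only while its forward() is running.
pipe.enable_sequential_cpu_offload(gpu_id=0)

image = pipe("a photograph of a horse on a beach").images[0]
image.save("horse.png")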
src/diffusers/pipelines/unclip/pipeline_unclip.py (+57, -13), view file @ c7b4acfb

@@ -23,7 +23,7 @@ from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from diffusers.schedulers import UnCLIPScheduler
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
 
-from ...utils import logging
+from ...utils import is_accelerate_available, logging
 from .text_proj import UnCLIPTextProjModel
@@ -115,7 +115,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
 
         # get prompt text embeddings
@@ -126,7 +126,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        text_mask = text_inputs.attention_mask.bool().to(self.device)
+        text_mask = text_inputs.attention_mask.bool().to(device)
 
         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
@@ -136,7 +136,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(self.device))
+        text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
         text_embeddings = text_encoder_output.text_embeds
         text_encoder_hidden_states = text_encoder_output.last_hidden_state
@@ -156,8 +156,8 @@ class UnCLIPPipeline(DiffusionPipeline):
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_text_mask = uncond_input.attention_mask.bool().to(self.device)
-            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(self.device))
+            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
 
             uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
             uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
@@ -187,6 +187,49 @@ class UnCLIPPipeline(DiffusionPipeline):
 
         return text_embeddings, text_encoder_hidden_states, text_mask
 
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+        models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
+        when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
+        models = [
+            self.decoder,
+            self.text_proj,
+            self.text_encoder,
+            self.super_res_first,
+            self.super_res_last,
+        ]
+        for cpu_offloaded_model in models:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
+            return self.device
+        for module in self.decoder.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     @torch.no_grad()
     def __call__(
         self,
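For background on what the two new methods rely on: `enable_sequential_cpu_offload` delegates to accelerate's `cpu_offload`, which saves a module's weights to a CPU state dict, replaces its parameters with meta tensors, and attaches hooks that materialize the weights on the execution device whenever `forward` runs; `_execution_device` then reads that target back from the attached `_hf_hook`. A standalone sketch of the mechanism, assuming `accelerate` is installed (the `nn.Linear` is just a stand-in module, not part of the pipeline):

import torch
from torch import nn
from accelerate import cpu_offload

target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

module = nn.Linear(4, 4)
# Weights are saved to a CPU state dict; the module's own parameters become
# meta tensors and are restored on `target` only for the duration of forward().
cpu_offload(module, target)

# The attached hook records the execution device, which is what the
# _execution_device property above reads back from the pipeline's decoder.
hook = getattr(module, "_hf_hook", None)
if hook is not None and getattr(hook, "execution_device", None) is not None:
    print(torch.device(hook.execution_device))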
@@ -254,25 +297,26 @@ class UnCLIPPipeline(DiffusionPipeline):
             batch_size = len(prompt)
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+        device = self._execution_device
+
         batch_size = batch_size * num_images_per_prompt
 
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
         text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
-            prompt, num_images_per_prompt, do_classifier_free_guidance
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance
         )
 
         # prior
 
-        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=self.device)
+        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
         prior_timesteps_tensor = self.prior_scheduler.timesteps
 
         embedding_dim = self.prior.config.embedding_dim
         prior_latents = self.prepare_latents(
             (batch_size, embedding_dim),
             text_embeddings.dtype,
-            self.device,
+            device,
             generator,
             prior_latents,
             self.prior_scheduler,
@@ -326,7 +370,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
 
-        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=self.device)
+        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps
 
         num_channels_latents = self.decoder.in_channels
@@ -335,7 +379,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
             text_encoder_hidden_states.dtype,
-            self.device,
+            device,
             generator,
             decoder_latents,
             self.decoder_scheduler,
@@ -378,7 +422,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         # super res
 
-        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=self.device)
+        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps
 
         channels = self.super_res_first.in_channels // 2
@@ -387,7 +431,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
             image_small.dtype,
-            self.device,
+            device,
             generator,
             super_res_latents,
             self.super_res_scheduler,
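One observable consequence of the changes above (a hedged illustration, not code from the commit; it assumes a CUDA device, and exact reprs may vary by torch version): once the offload hooks are installed, the pipeline's modules report the meta device, so `self.device` can no longer serve as the execution target inside `__call__`. That is why the schedulers and `prepare_latents` now receive `device = self._execution_device` instead.

from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
pipe.enable_sequential_cpu_offload(gpu_id=0)

print(pipe.device)             # expected: device(type='meta')
print(pipe._execution_device)  # expected: device(type='cuda', index=0)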
tests/pipelines/unclip/test_unclip.py (+28, -2), view file @ c7b4acfb

@@ -261,10 +261,10 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
     def test_unclip_karlo(self):
         expected_image = load_numpy(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/unclip/karlo_v1_alpha_horse.npy"
+            "/unclip/karlo_v1_alpha_horse_fp16.npy"
         )
 
-        pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
+        pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
         pipeline = pipeline.to(torch_device)
         pipeline.set_progress_bar_config(disable=None)
@@ -280,3 +280,29 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
         assert image.shape == (256, 256, 3)
 
         assert np.abs(expected_image - image).max() < 1e-2
+
+    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+        pipe.enable_sequential_cpu_offload()
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        _ = pipe(
+            "horse",
+            num_images_per_prompt=1,
+            generator=generator,
+            prior_num_inference_steps=2,
+            decoder_num_inference_steps=2,
+            super_res_num_inference_steps=2,
+            output_type="np",
+        )
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
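The same peak-memory check can be reproduced outside the test suite. A rough sketch, assuming a CUDA device and the default fp32 checkpoint; the 2-step schedules and the 1.5 GB bound mirror the test above, and actual numbers will vary by environment.

import torch
from diffusers import UnCLIPPipeline

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
pipe.enable_attention_slicing()
pipe.enable_sequential_cpu_offload()

_ = pipe(
    "horse",
    prior_num_inference_steps=2,
    decoder_num_inference_steps=2,
    super_res_num_inference_steps=2,
    output_type="np",
)

print(torch.cuda.max_memory_allocated() / 10**9, "GB peak")  # expected to stay under ~1.5 GB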