renzhc / diffusers_dcu · Commits · 92f15f5b

Unverified commit 92f15f5b, authored Sep 25, 2023 by Dhruv Nair, committed by GitHub on Sep 25, 2023
Model CPU offload fix for BLIPDiffusion (#5174)
cpu offload fix for blip diffusion
parent 22b19d57
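The fix wires the BLIP-Diffusion pipelines into hook-based model CPU offload. A minimal usage sketch, assuming the `Salesforce/blipdiffusion` checkpoint and an illustrative reference-image URL (neither is part of this commit):

import torch
from diffusers import BlipDiffusionPipeline
from diffusers.utils import load_image

pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16
)
# Each sub-model is moved to the GPU only for its forward pass, then
# returned to CPU, following model_cpu_offload_seq (added below).
pipe.enable_model_cpu_offload()

reference_image = load_image("https://example.com/dog.png")  # illustrative URL
image = pipe(
    prompt="swimming underwater",
    reference_image=reference_image,
    source_subject_category="dog",
    target_subject_category="dog",
    num_inference_steps=25,
    guidance_scale=7.5,
).images[0]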
Changes: 2 changed files with 29 additions and 12 deletions (+29 -12)

src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py  +15 -6
src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py  +14 -6
src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py (view file @ 92f15f5b)
@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
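The new `model_cpu_offload_seq` attribute tells the base `DiffusionPipeline.enable_model_cpu_offload()` which sub-models to wrap with offload hooks and in what order. A rough sketch of that machinery, assuming accelerate's `cpu_offload_with_hook` (this helper function is illustrative, not code from the diff):

import torch
from accelerate import cpu_offload_with_hook

def enable_model_cpu_offload_sketch(pipe, device="cuda"):
    # Chain the hooks: when a model in the sequence is about to run on the
    # GPU, the previous model's hook offloads it back to CPU first.
    hook = None
    for name in pipe.model_cpu_offload_seq.split("->"):
        _, hook = cpu_offload_with_hook(
            getattr(pipe, name), torch.device(device), prev_module_hook=hook
        )
    pipe.final_offload_hook = hook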
@@ -155,7 +157,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
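`encode_prompt` previously moved tokenized inputs to `self.device`, which reports `cpu` while offload hooks hold the models on CPU between forward passes; `_execution_device` instead asks the hooks where the models will actually run. A simplified paraphrase of that behavior (illustrative, not the exact base-class code):

import torch

def execution_device_sketch(pipe):
    # A sub-model wrapped by accelerate carries an _hf_hook whose
    # execution_device (e.g. cuda:0) is where forward passes happen,
    # even though the weights are parked on CPU in between.
    for model in pipe.components.values():
        hook = getattr(model, "_hf_hook", None)
        if hook is not None and getattr(hook, "execution_device", None) is not None:
            return torch.device(hook.execution_device)
    return pipe.device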
@@ -166,7 +170,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -249,11 +253,12 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
+
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -271,7 +276,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
             max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -283,7 +288,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
                 return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(
-                input_ids=uncond_input.input_ids.to(self.device),
+                input_ids=uncond_input.input_ids.to(device),
                 ctx_embeddings=None,
             )[0]
             # For classifier free guidance, we need to do two forward passes.
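The unconditional embeddings feed classifier-free guidance: the UNet predicts noise for both embeddings and the two predictions are blended. A generic sketch of that blend in the usual diffusers style (function and variable names are illustrative):

import torch

def guided_noise_pred(unet, latents, t, text_embeddings, uncond_embeddings, guidance_scale):
    # Run the two forward passes as one batched call: unconditional first.
    latent_input = torch.cat([latents] * 2)
    embeddings = torch.cat([uncond_embeddings, text_embeddings])
    noise_pred = unet(latent_input, t, encoder_hidden_states=embeddings).sample
    noise_uncond, noise_text = noise_pred.chunk(2)
    # Steer the prediction away from the unconditional direction.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)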
@@ -300,7 +305,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
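`prepare_latents` now receives the hook-aware device so the initial noise tensor is created where the UNet will actually execute. A sketch of what such a helper typically does, using `randn_tensor` from `diffusers.utils.torch_utils` (the helper name and shape layout are illustrative):

from diffusers.utils.torch_utils import randn_tensor

def prepare_latents_sketch(scheduler, batch_size, num_channels, height, width,
                           dtype, device, generator, latents=None):
    shape = (batch_size, num_channels, height, width)
    if latents is None:
        # Draw the noise directly on the execution device, in the UNet's dtype.
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device=device, dtype=dtype)
    # Scale by the scheduler's initial noise sigma, as the pipeline does above.
    return latents * scheduler.init_noise_sigma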
@@ -330,9 +335,13 @@ class BlipDiffusionPipeline(DiffusionPipeline):
                 t,
                 latents,
             )["prev_sample"]
+
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
 
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)
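`maybe_free_model_hooks()` at the end of `__call__` re-offloads every sub-model to CPU (and is a no-op when offload was never enabled), so repeated pipeline calls start from a consistent state and GPU memory is released between them. Continuing the usage sketch from above, with `pipe` and `reference_image` as defined there:

# After the first call, maybe_free_model_hooks() has already parked all
# models back on CPU, so the second call re-offloads cleanly.
first = pipe(
    prompt="swimming underwater", reference_image=reference_image,
    source_subject_category="dog", target_subject_category="dog",
).images[0]
second = pipe(
    prompt="wearing a party hat", reference_image=reference_image,
    source_subject_category="dog", target_subject_category="dog",
).images[0]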
src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py (view file @ 92f15f5b)
@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -166,7 +168,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -177,7 +181,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -297,11 +301,12 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
+
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -319,7 +324,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         # 3. unconditional embedding
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
@@ -332,7 +337,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
                 return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(
-                input_ids=uncond_input.input_ids.to(self.device),
+                input_ids=uncond_input.input_ids.to(device),
                 ctx_embeddings=None,
             )[0]
             # For classifier free guidance, we need to do two forward passes.
@@ -348,7 +353,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -399,6 +404,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
 
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)