OpenDAS / diffusers / Commits

Commit 92f15f5b (unverified), parent 22b19d57
Authored Sep 25, 2023 by Dhruv Nair; committed via GitHub on Sep 25, 2023

Model CPU offload fix for BLIPDiffusion (#5174)

CPU offload fix for BLIP Diffusion.
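This change makes model CPU offload usable with the BLIP Diffusion pipelines. A minimal usage sketch of the fixed path, assuming the Salesforce/blipdiffusion checkpoint from the diffusers docs and a CUDA device (the reference image path, subject categories, and prompt are illustrative):

import torch
from diffusers import BlipDiffusionPipeline
from diffusers.utils import load_image

pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16
)
# With model_cpu_offload_seq defined (see the diff below), each sub-model is
# moved to the GPU only for its forward pass and offloaded back to CPU after.
pipe.enable_model_cpu_offload()

ref_image = load_image("path/to/subject.png")  # illustrative subject photo
image = pipe(
    prompt="painting by van gogh",
    reference_image=ref_image,
    source_subject_category="dog",
    target_subject_category="dog",
    guidance_scale=7.5,
    num_inference_steps=25,
).images[0]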
Showing 2 changed files with 29 additions and 12 deletions (+29 −12):

    src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py         +15 −6
    src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py  +14 −6
src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
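The model_cpu_offload_seq string declares the order in which the sub-models execute, which enable_model_cpu_offload() uses to chain offload hooks. A rough sketch of that mechanism via accelerate's cpu_offload_with_hook (simplified; not the library's exact code):

from accelerate import cpu_offload_with_hook

def sketch_enable_model_cpu_offload(pipe, device="cuda"):
    # Walk the declared sequence, e.g. "qformer->text_encoder->unet->vae".
    hook = None
    for name in pipe.model_cpu_offload_seq.split("->"):
        model = getattr(pipe, name)
        # Each hook moves its model onto `device` just before the forward
        # pass and offloads the previous model in the chain back to CPU.
        _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)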
@@ -155,7 +157,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -166,7 +170,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
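Why swap self.device for the execution device: with offload hooks attached, sub-models rest on the CPU between forward passes, so self.device (derived from module parameters) can report cpu while compute actually runs on the accelerator; _execution_device accounts for the hooks. An illustrative session, continuing the sketch above (printed values assume a single CUDA GPU):

pipe.enable_model_cpu_offload()

# The pipeline's nominal device follows its (currently offloaded) weights...
print(pipe.device)             # device(type='cpu')
# ...while the hooks will run every forward pass on the accelerator.
print(pipe._execution_device)  # device(type='cuda', index=0)

# Inputs built with self.device would land on CPU and collide with GPU
# weights at runtime, which is why the diff threads `device` through
# encode_prompt() and the tensor .to(...) calls.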
@@ -249,11 +253,12 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -271,7 +276,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
             max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -283,7 +288,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         uncond_embeddings = self.text_encoder(
-            input_ids=uncond_input.input_ids.to(self.device),
+            input_ids=uncond_input.input_ids.to(device),
             ctx_embeddings=None,
         )[0]
         # For classifier free guidance, we need to do two forward passes.
@@ -300,7 +305,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -330,9 +335,13 @@ class BlipDiffusionPipeline(DiffusionPipeline):
                 t,
                 latents,
             )["prev_sample"]
+
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
 
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)
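The new maybe_free_model_hooks() call at the end of __call__ returns all sub-models to the CPU and re-arms the offload hooks (it is a no-op when offload was never enabled), so repeated generations do not accumulate GPU memory. Continuing the earlier sketch (prompts illustrative):

pipe.enable_model_cpu_offload()
for style in ["impressionist painting", "pencil sketch"]:
    # Each call ends by freeing the model hooks, so peak GPU usage stays at
    # roughly one sub-model at a time across iterations.
    image = pipe(
        prompt=style,
        reference_image=ref_image,
        source_subject_category="dog",
        target_subject_category="dog",
    ).images[0]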
src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -166,7 +168,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -177,7 +181,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -297,11 +301,12 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -319,7 +324,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         # 3. unconditional embedding
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
@@ -332,7 +337,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         uncond_embeddings = self.text_encoder(
-            input_ids=uncond_input.input_ids.to(self.device),
+            input_ids=uncond_input.input_ids.to(device),
             ctx_embeddings=None,
         )[0]
         # For classifier free guidance, we need to do two forward passes.
@@ -348,7 +353,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -399,6 +404,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
 
         if not return_dict:
             return (image,)
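The ControlNet variant receives the identical treatment, so the same offload pattern applies there. A sketch, assuming the Salesforce/blipdiffusion-controlnet checkpoint from the diffusers docs; arguments are passed positionally as in the docs example, and the conditioning image is illustrative:

import torch
from diffusers import BlipDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
)
controlnet_pipe.enable_model_cpu_offload()

ref_image = load_image("path/to/subject.png")   # subject to preserve
canny_edges = load_image("path/to/canny.png")   # ControlNet conditioning map
image = controlnet_pipe(
    "on a marble table",  # prompt
    ref_image,            # reference image of the subject
    canny_edges,          # conditioning image
    "dog",                # source subject category
    "dog",                # target subject category
).images[0]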