renzhc / diffusers_dcu · Commits · 45b42d12

Commit 45b42d12 (unverified), authored Mar 27, 2024 by Disty0, committed by GitHub Mar 26, 2024

Add device arg to offloading with combined pipelines (#7471)

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

parent 5199ee4f
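
In effect, the combined pipelines now accept and forward an explicit `device` (plus an optional `gpu_id`) to the offloading helpers of their internal `prior_pipe` and `decoder_pipe`, instead of hard-coding the default CUDA device. A minimal usage sketch of the new argument follows; the checkpoint id, prompt, and device string are illustrative and not part of this commit:

import torch
from diffusers import AutoPipelineForText2Image

# Loads the combined (prior + decoder) Kandinsky 2.2 text-to-image pipeline.
# The repo id below is an example checkpoint, not something introduced here.
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)

# Previously only `gpu_id` was accepted; the offload target can now be named
# explicitly and is passed through to both prior_pipe and decoder_pipe.
pipe.enable_model_cpu_offload(device="cuda")

image = pipe("a photo of an astronaut riding a horse").images[0]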
Changes: 3 changed files with 24 additions and 24 deletions (+24 −24)

src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py       +12 −12
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py   +6 −6
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py           +6 −6
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py

@@ -178,7 +178,7 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -186,8 +186,8 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
         Note that offloading happens on a submodule basis. Memory savings are higher than with
         `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -405,17 +405,17 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_model_cpu_offload(self, gpu_id=0):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
         """
-        self.prior_pipe.enable_model_cpu_offload()
-        self.decoder_pipe.enable_model_cpu_offload()
+        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -423,8 +423,8 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
         Note that offloading happens on a submodule basis. Memory savings are higher than with
         `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -653,7 +653,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -661,8 +661,8 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
         Note that offloading happens on a submodule basis. Memory savings are higher than with
         `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
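
For the Kandinsky combined pipelines above, here is a hedged sketch of the updated `enable_sequential_cpu_offload` signature, assuming a machine with a second GPU (the checkpoint id and the `cuda:1` target are illustrative):

import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)

# Submodule-level offloading; the explicit device is now forwarded to both
# self.prior_pipe and self.decoder_pipe rather than defaulting to "cuda".
pipe.enable_sequential_cpu_offload(device="cuda:1")

image = pipe("a cinematic photo of a red fox in the snow").images[0]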
src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py

@@ -117,25 +117,25 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_model_cpu_offload(self, gpu_id=0):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
         """
-        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
         Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
         GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
         Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
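
A similar sketch for `StableCascadeCombinedPipeline`, assuming the combined checkpoint can be loaded directly from a single repo (the repo id, dtype, and prompt are illustrative, not asserted by this commit):

import torch
from diffusers import StableCascadeCombinedPipeline

# Example checkpoint; the actual repo layout for the combined pipeline may differ.
pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
)

# A torch.device also works for the new argument; gpu_id remains optional.
pipe.enable_model_cpu_offload(device=torch.device("cuda", 0))

image = pipe("an anthropomorphic cat wearing a suit").images[0]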
src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

@@ -112,25 +112,25 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_model_cpu_offload(self, gpu_id=0):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
         """
-        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
         Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
         GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
         Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)