renzhc / diffusers_dcu · Commits

Commit 8e7d6c03 (unverified)
Authored Sep 28, 2024 by Sayak Paul; committed by GitHub on Sep 28, 2024
Parent: b28675c6

[chore] fix: retain memory utility. (#9543)

* fix: retain memory utility.
* fix
* quality
* free_memory.

Showing 6 changed files with 33 additions and 35 deletions:
examples/cogvideo/train_cogvideox_lora.py           +3  -5
examples/controlnet/train_controlnet_flux.py        +5  -3
examples/controlnet/train_controlnet_sd3.py         +6  -7
examples/dreambooth/train_dreambooth_lora_flux.py   +7  -4
examples/dreambooth/train_dreambooth_lora_sd3.py    +10 -10
src/diffusers/training_utils.py                     +2  -6
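Every call site follows the same mechanical migration: the old clear_objs_and_retain_memory(objs) helper is gone, and callers now delete the references they own and then call the new free_memory() utility. A minimal sketch of the pattern (the pipe variable here is a hypothetical stand-in for any large object, such as a loaded pipeline):

    import torch
    from diffusers.training_utils import free_memory

    # Hypothetical stand-in for a large object (e.g. a loaded pipeline).
    pipe = torch.nn.Linear(4096, 4096)

    # Old pattern, removed by this commit:
    #     clear_objs_and_retain_memory([pipe])
    # New pattern: unbind the name in the owning scope, then clear caches.
    del pipe
    free_memory()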
examples/cogvideo/train_cogvideox_lora.py

@@ -38,10 +38,7 @@ from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPi
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module

@@ -726,7 +723,8 @@ def log_validation(
             }
         )

-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()

     return videos
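The reason the del moves to the call site is Python scoping: the removed utility executed del on its own local names, so the caller's variables kept the objects alive and nothing actually became collectable. A self-contained demonstration of that behavior (Big and try_to_free are illustrative names, not from the commit):

    import gc
    import weakref


    class Big:
        """Hypothetical heavyweight object."""


    def try_to_free(objs):
        # Mirrors the removed utility: `del` only unbinds the *local* name,
        # so the caller's reference keeps the object alive.
        for obj in objs:
            del obj
        gc.collect()


    big = Big()
    probe = weakref.ref(big)

    try_to_free([big])
    print(probe() is None)  # False -- the caller's `big` still holds it

    del big  # unbinding in the owning scope is what actually works
    gc.collect()
    print(probe() is None)  # True -- now the object was collected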
examples/controlnet/train_controlnet_flux.py

@@ -54,7 +54,7 @@ from diffusers import (
 from diffusers.models.controlnet_flux import FluxControlNetModel
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
-from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling
+from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available

@@ -193,7 +193,8 @@ def log_validation(
     else:
         logger.warning(f"image logging not implemented for {tracker.name}")

-    clear_objs_and_retain_memory([pipeline])
+    del pipeline
+    free_memory()

     return image_logs

@@ -1103,7 +1104,8 @@ def main(args):
         compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
     )

-    clear_objs_and_retain_memory([text_encoders, tokenizers])
+    del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+    free_memory()

     # Then get the training dataset ready to be passed to the dataloader.
     train_dataset = prepare_train_dataset(train_dataset, accelerator)
examples/controlnet/train_controlnet_sd3.py

@@ -49,11 +49,7 @@ from diffusers import (
     StableDiffusion3ControlNetPipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import (
-    clear_objs_and_retain_memory,
-    compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3,
-)
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module

@@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
     else:
         logger.warning(f"image logging not implemented for {tracker.name}")

-    clear_objs_and_retain_memory(pipeline)
+    del pipeline
+    free_memory()

     if not is_final_validation:
         controlnet.to(accelerator.device)

@@ -1131,7 +1128,9 @@ def main(args):
     new_fingerprint = Hasher.hash(args)
     train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)

-    clear_objs_and_retain_memory(text_encoders + tokenizers)
+    del text_encoder_one, text_encoder_two, text_encoder_three
+    del tokenizer_one, tokenizer_two, tokenizer_three
+    free_memory()

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
examples/dreambooth/train_dreambooth_lora_flux.py

@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    free_memory,
 )
 from diffusers.utils import (
     check_min_version,

@@ -1437,7 +1437,8 @@ def main(args):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
+        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't

@@ -1480,7 +1481,8 @@ def main(args):
                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

         if args.validation_prompt is None:
-            clear_objs_and_retain_memory([vae])
+            del vae
+            free_memory()

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False

@@ -1817,7 +1819,8 @@ def main(args):
             torch_dtype=weight_dtype,
         )
         if not args.train_text_encoder:
-            clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
+            del text_encoder_one, text_encoder_two
+            free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
examples/dreambooth/train_dreambooth_lora_sd3.py

@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    free_memory,
 )
 from diffusers.utils import (
     check_min_version,

@@ -211,7 +211,8 @@ def log_validation(
             }
         )

-    clear_objs_and_retain_memory(objs=[pipeline])
+    del pipeline
+    free_memory()

     return images

@@ -1106,7 +1107,8 @@ def main(args):
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)

-        clear_objs_and_retain_memory(objs=[pipeline])
+        del pipeline
+        free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:

@@ -1453,9 +1455,9 @@ def main(args):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        clear_objs_and_retain_memory(
-            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
-        )
+        del tokenizers, text_encoders
+        del text_encoder_one, text_encoder_two, text_encoder_three
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't

@@ -1791,11 +1793,9 @@ def main(args):
                 epoch=epoch,
                 torch_dtype=weight_dtype,
             )
-            objs = []
-            if not args.train_text_encoder:
-                objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

-            clear_objs_and_retain_memory(objs=objs)
+            del text_encoder_one, text_encoder_two, text_encoder_three
+            free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
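The comment retained in the hunk above names a second pitfall the new code avoids: collecting objects into a list and releasing the list does not release the objects while other names still bind them, which is why the script now deletes tokenizers/text_encoders and each individual encoder by name. A toy illustration (Encoder is a made-up class; the variable names mirror the script):

    import gc
    import weakref


    class Encoder:
        """Made-up stand-in for a text encoder."""


    text_encoder_one = Encoder()
    text_encoders = [text_encoder_one]
    probe = weakref.ref(text_encoder_one)

    del text_encoders  # only the wrapping list is released
    gc.collect()
    print(probe() is None)  # False: the named reference survives

    del text_encoder_one  # the per-name deletes in the diff above
    gc.collect()
    print(probe() is None)  # True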
src/diffusers/training_utils.py

@@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting


-def clear_objs_and_retain_memory(objs: List[Any]):
-    """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
-    if len(objs) >= 1:
-        for obj in objs:
-            del obj
-
+def free_memory():
+    """Runs garbage collection. Then clears the cache of the available accelerator."""
     gc.collect()

     if torch.cuda.is_available():
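The displayed hunk truncates just after the torch.cuda.is_available() check. For context, a plausible sketch of the complete utility after this commit, assuming it clears the CUDA cache and falls back to MPS where available (additional accelerator backends may be handled similarly in the real function):

    import gc

    import torch


    def free_memory():
        """Runs garbage collection. Then clears the cache of the available accelerator."""
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif torch.backends.mps.is_available():
            # Assumption: non-CUDA backends are cleared the same way.
            torch.mps.empty_cache()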