Commit 8ba90aa7 (unverified), renzhc/diffusers_dcu
Authored Sep 03, 2024 by Sayak Paul; committed by GitHub on Sep 03, 2024

chore: add a cleaning utility to be useful during training. (#9240)
parent 9d49b45b

Showing 2 changed files with 26 additions and 15 deletions (+26, -15):

examples/dreambooth/train_dreambooth_lora_sd3.py  (+9, -15)
src/diffusers/training_utils.py  (+17, -0)
examples/dreambooth/train_dreambooth_lora_sd3.py  (view file @ 8ba90aa7)

@@ -15,7 +15,6 @@
 import argparse
 import copy
-import gc
 import itertools
 import logging
 import math

@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
+    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
 )

@@ -210,9 +210,7 @@ def log_validation(
                 }
             )

-    del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clear_objs_and_retain_memory(objs=[pipeline])

     return images

@@ -1107,9 +1105,7 @@ def main(args):
                     image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                     image.save(image_filename)

-        del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory(objs=[pipeline])

     # Handle the repository creation
     if accelerator.is_main_process:

@@ -1455,12 +1451,10 @@ def main(args):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        del tokenizers, text_encoders
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        del text_encoder_one, text_encoder_two, text_encoder_three
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory(
+            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
+        )

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't

@@ -1795,11 +1789,11 @@ def main(args):
                     pipeline_args=pipeline_args,
                     epoch=epoch,
                 )
+                objs = []
                 if not args.train_text_encoder:
-                    del text_encoder_one, text_encoder_two, text_encoder_three
+                    objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

-                torch.cuda.empty_cache()
-                gc.collect()
+                clear_objs_and_retain_memory(objs=objs)

     # Save the lora layers
     accelerator.wait_for_everyone()
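The net effect of the script changes above is that each manual `del ...` / `gc.collect()` / `torch.cuda.empty_cache()` block becomes a single call to the new helper. Below is a minimal sketch of that pattern in isolation; it is not part of the commit, and the function name `run_validation`, the checkpoint id, and the inference settings are illustrative assumptions only.

```python
# Sketch only: the consolidated cleanup pattern introduced by this commit,
# assuming a diffusers build that contains clear_objs_and_retain_memory.
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.training_utils import clear_objs_and_retain_memory


def run_validation(prompt: str):
    # Hypothetical throwaway validation pipeline; the checkpoint id is a placeholder.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=dtype
    ).to(device)
    images = pipeline(prompt, num_inference_steps=25).images

    # Before this commit the script did:
    #     del pipeline
    #     if torch.cuda.is_available():
    #         torch.cuda.empty_cache()
    # The helper folds the deletion, gc.collect(), and the accelerator
    # cache clearing into one call.
    clear_objs_and_retain_memory(objs=[pipeline])
    return images
```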
src/diffusers/training_utils.py  (view file @ 8ba90aa7)

 import contextlib
 import copy
+import gc
 import math
 import random
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting


+def clear_objs_and_retain_memory(objs: List[Any]):
+    """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
+    if len(objs) >= 1:
+        for obj in objs:
+            del obj
+
+    gc.collect()
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+    elif is_torch_npu_available():
+        torch_npu.empty_cache()
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
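For completeness, here is a hedged usage sketch of the new utility outside the DreamBooth script. Only `clear_objs_and_retain_memory`, its module path, and the `objs=` keyword come from the diff above; `encode_then_release`, the stand-in `Linear` encoder, and the tensor shapes are invented for illustration. The point of the helper is that backend detection (CUDA, MPS, NPU) lives in one place instead of being repeated at every cleanup site.

```python
# Sketch only: releasing a large one-shot module with the new helper.
import torch
from diffusers.training_utils import clear_objs_and_retain_memory


def encode_then_release(batch: torch.Tensor) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Stand-in for a big frozen encoder that is only needed once up front.
    encoder = torch.nn.Linear(4096, 4096).to(device)
    with torch.no_grad():
        embeddings = encoder(batch.to(device)).cpu()

    # Replaces the usual `del encoder; gc.collect(); torch.cuda.empty_cache()`
    # boilerplate; the helper also covers the MPS and NPU cache branches.
    clear_objs_and_retain_memory(objs=[encoder])
    return embeddings


if __name__ == "__main__":
    print(encode_then_release(torch.randn(8, 4096)).shape)
```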