Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9e8ee2ac
Unverified
Commit
9e8ee2ac
authored
Feb 13, 2023
by
Will Berman
Committed by
GitHub
Feb 13, 2023
Browse files
dreambooth checkpointing tests and docs (#2339)
parent
6782b70d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
86 additions
and
3 deletions
+86
-3
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+5
-3
examples/test_examples.py
examples/test_examples.py
+81
-0
No files found.
examples/dreambooth/train_dreambooth.py
View file @
9e8ee2ac
...
...
@@ -188,9 +188,11 @@ def parse_args(input_args=None):
type
=
int
,
default
=
500
,
help
=
(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
"instructions."
),
)
parser
.
add_argument
(
...
...
examples/test_examples.py
View file @
9e8ee2ac
...
...
@@ -25,6 +25,8 @@ from typing import List
from
accelerate.utils
import
write_basic_config
from
diffusers
import
DiffusionPipeline
,
UNet2DConditionModel
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
...
@@ -140,6 +142,85 @@ class ExamplesTestsAccelerate(unittest.TestCase):
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"unet"
,
"diffusion_pytorch_model.bin"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"scheduler"
,
"scheduler_config.json"
)))
def
test_dreambooth_checkpointing
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
instance_prompt
=
"photo"
pretrained_model_name_or_path
=
"hf-internal-testing/tiny-stable-diffusion-pipe"
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args
=
f
"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path
{
pretrained_model_name_or_path
}
--instance_data_dir docs/source/en/imgs
--instance_prompt
{
instance_prompt
}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir
{
tmpdir
}
--checkpointing_steps=2
--seed=0
"""
.
split
()
run_command
(
self
.
_launch_args
+
initial_run_args
)
# check can run the original fully trained output pipeline
pipe
=
DiffusionPipeline
.
from_pretrained
(
tmpdir
,
safety_checker
=
None
)
pipe
(
instance_prompt
,
num_inference_steps
=
2
)
# check checkpoint directories exist
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
)))
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-4"
)))
# check can run an intermediate checkpoint
unet
=
UNet2DConditionModel
.
from_pretrained
(
tmpdir
,
subfolder
=
"checkpoint-2/unet"
)
pipe
=
DiffusionPipeline
.
from_pretrained
(
pretrained_model_name_or_path
,
unet
=
unet
,
safety_checker
=
None
)
pipe
(
instance_prompt
,
num_inference_steps
=
2
)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil
.
rmtree
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args
=
f
"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path
{
pretrained_model_name_or_path
}
--instance_data_dir docs/source/en/imgs
--instance_prompt
{
instance_prompt
}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir
{
tmpdir
}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--seed=0
"""
.
split
()
run_command
(
self
.
_launch_args
+
resume_run_args
)
# check can run new fully trained pipeline
pipe
=
DiffusionPipeline
.
from_pretrained
(
tmpdir
,
safety_checker
=
None
)
pipe
(
instance_prompt
,
num_inference_steps
=
2
)
# check old checkpoints do not exist
self
.
assertFalse
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
)))
# check new checkpoints exist
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-4"
)))
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-6"
)))
def
test_text_to_image
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
test_args
=
f
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment