Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b214bb25
Unverified
Commit
b214bb25
authored
Feb 16, 2023
by
Will Berman
Committed by
GitHub
Feb 16, 2023
Browse files
train_text_to_image EMAModel saving (#2341)
parent
de9ce9e9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
166 additions
and
4 deletions
+166
-4
examples/test_examples.py
examples/test_examples.py
+165
-3
examples/text_to_image/train_text_to_image.py
examples/text_to_image/train_text_to_image.py
+1
-1
No files found.
examples/test_examples.py
View file @
b214bb25
...
@@ -143,10 +143,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
...
@@ -143,10 +143,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"scheduler"
,
"scheduler_config.json"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"scheduler"
,
"scheduler_config.json"
)))
def
test_dreambooth_checkpointing
(
self
):
def
test_dreambooth_checkpointing
(
self
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
instance_prompt
=
"photo"
instance_prompt
=
"photo"
pretrained_model_name_or_path
=
"hf-internal-testing/tiny-stable-diffusion-pipe"
pretrained_model_name_or_path
=
"hf-internal-testing/tiny-stable-diffusion-pipe"
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
# Run training script with checkpointing
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
# Should create checkpoints at steps 2, 4
...
@@ -244,3 +244,165 @@ class ExamplesTestsAccelerate(unittest.TestCase):
...
@@ -244,3 +244,165 @@ class ExamplesTestsAccelerate(unittest.TestCase):
# save_pretrained smoke test
# save_pretrained smoke test
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"unet"
,
"diffusion_pytorch_model.bin"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"unet"
,
"diffusion_pytorch_model.bin"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"scheduler"
,
"scheduler_config.json"
)))
self
.
assertTrue
(
os
.
path
.
isfile
(
os
.
path
.
join
(
tmpdir
,
"scheduler"
,
"scheduler_config.json"
)))
def
test_text_to_image_checkpointing
(
self
):
pretrained_model_name_or_path
=
"hf-internal-testing/tiny-stable-diffusion-pipe"
prompt
=
"a prompt"
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args
=
f
"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path
{
pretrained_model_name_or_path
}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir
{
tmpdir
}
--checkpointing_steps=2
--seed=0
"""
.
split
()
run_command
(
self
.
_launch_args
+
initial_run_args
)
pipe
=
DiffusionPipeline
.
from_pretrained
(
tmpdir
,
safety_checker
=
None
)
pipe
(
prompt
,
num_inference_steps
=
2
)
# check checkpoint directories exist
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
)))
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-4"
)))
# check can run an intermediate checkpoint
unet
=
UNet2DConditionModel
.
from_pretrained
(
tmpdir
,
subfolder
=
"checkpoint-2/unet"
)
pipe
=
DiffusionPipeline
.
from_pretrained
(
pretrained_model_name_or_path
,
unet
=
unet
,
safety_checker
=
None
)
pipe
(
prompt
,
num_inference_steps
=
2
)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil
.
rmtree
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args
=
f
"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path
{
pretrained_model_name_or_path
}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir
{
tmpdir
}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--seed=0
"""
.
split
()
run_command
(
self
.
_launch_args
+
resume_run_args
)
# check can run new fully trained pipeline
pipe
=
DiffusionPipeline
.
from_pretrained
(
tmpdir
,
safety_checker
=
None
)
pipe
(
prompt
,
num_inference_steps
=
2
)
# check old checkpoints do not exist
self
.
assertFalse
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
)))
# check new checkpoints exist
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-4"
)))
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-6"
)))
def
test_text_to_image_checkpointing_use_ema
(
self
):
pretrained_model_name_or_path
=
"hf-internal-testing/tiny-stable-diffusion-pipe"
prompt
=
"a prompt"
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args
=
f
"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path
{
pretrained_model_name_or_path
}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir
{
tmpdir
}
--checkpointing_steps=2
--use_ema
--seed=0
"""
.
split
()
run_command
(
self
.
_launch_args
+
initial_run_args
)
pipe
=
DiffusionPipeline
.
from_pretrained
(
tmpdir
,
safety_checker
=
None
)
pipe
(
prompt
,
num_inference_steps
=
2
)
# check checkpoint directories exist
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
)))
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-4"
)))
# check can run an intermediate checkpoint
unet
=
UNet2DConditionModel
.
from_pretrained
(
tmpdir
,
subfolder
=
"checkpoint-2/unet"
)
pipe
=
DiffusionPipeline
.
from_pretrained
(
pretrained_model_name_or_path
,
unet
=
unet
,
safety_checker
=
None
)
pipe
(
prompt
,
num_inference_steps
=
2
)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil
.
rmtree
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args
=
f
"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path
{
pretrained_model_name_or_path
}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir
{
tmpdir
}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--use_ema
--seed=0
"""
.
split
()
run_command
(
self
.
_launch_args
+
resume_run_args
)
# check can run new fully trained pipeline
pipe
=
DiffusionPipeline
.
from_pretrained
(
tmpdir
,
safety_checker
=
None
)
pipe
(
prompt
,
num_inference_steps
=
2
)
# check old checkpoints do not exist
self
.
assertFalse
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-2"
)))
# check new checkpoints exist
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-4"
)))
self
.
assertTrue
(
os
.
path
.
isdir
(
os
.
path
.
join
(
tmpdir
,
"checkpoint-6"
)))
examples/text_to_image/train_text_to_image.py
View file @
b214bb25
...
@@ -413,7 +413,7 @@ def main():
...
@@ -413,7 +413,7 @@ def main():
ema_unet
=
UNet2DConditionModel
.
from_pretrained
(
ema_unet
=
UNet2DConditionModel
.
from_pretrained
(
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
args
.
pretrained_model_name_or_path
,
subfolder
=
"unet"
,
revision
=
args
.
revision
)
)
ema_unet
=
EMAModel
(
ema_unet
.
parameters
())
ema_unet
=
EMAModel
(
ema_unet
.
parameters
()
,
model_cls
=
UNet2DConditionModel
,
model_config
=
ema_unet
.
config
)
if
args
.
enable_xformers_memory_efficient_attention
:
if
args
.
enable_xformers_memory_efficient_attention
:
if
is_xformers_available
():
if
is_xformers_available
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment