Unverified Commit b214bb25 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

train_text_to_image EMAModel saving (#2341)

parent de9ce9e9
......@@ -143,10 +143,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_dreambooth_checkpointing(self):
with tempfile.TemporaryDirectory() as tmpdir:
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
......@@ -244,3 +244,165 @@ class ExamplesTestsAccelerate(unittest.TestCase):
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_text_to_image_checkpointing(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
# check new checkpoints exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
def test_text_to_image_checkpointing_use_ema(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--use_ema
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--use_ema
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
# check new checkpoints exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
......@@ -413,7 +413,7 @@ def main():
ema_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
ema_unet = EMAModel(ema_unet.parameters())
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment