Unverified commit 4909b1e3, authored by Sayak Paul and committed by GitHub
Browse files

[Examples] fix checkpointing and casting bugs in `train_text_to_image_lora_sdxl.py` (#4632)

* fix: casting issues.

* fix checkpointing.

* tests

* fix: bugs
parent 052bf328
...@@ -828,6 +828,87 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -828,6 +828,87 @@ class ExamplesTestsAccelerate(unittest.TestCase):
{"checkpoint-4", "checkpoint-6"}, {"checkpoint-4", "checkpoint-6"},
) )
def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
    # Train the SDXL LoRA example with a checkpoint cap, then verify that the
    # resulting weights load into a pipeline and that only the newest
    # checkpoints remain on disk.
    pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
    prompt = "a prompt"

    with tempfile.TemporaryDirectory() as tmpdir:
        # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
        # -> checkpoints are written at steps 2, 4 and 6, and the one from
        #    step 2 is expected to be removed to honour the limit.
        training_args = f"""
            examples/text_to_image/train_text_to_image_lora_sdxl.py
            --pretrained_model_name_or_path {pipeline_path}
            --dataset_name hf-internal-testing/dummy_image_text_data
            --resolution 64
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps 7
            --learning_rate 5.0e-04
            --scale_lr
            --lr_scheduler constant
            --lr_warmup_steps 0
            --output_dir {tmpdir}
            --checkpointing_steps=2
            --checkpoints_total_limit=2
            """.split()
        launch_command = self._launch_args + training_args
        run_command(launch_command)

        # The trained LoRA weights must be loadable and usable for inference.
        pipe = DiffusionPipeline.from_pretrained(pipeline_path)
        pipe.load_lora_weights(tmpdir)
        pipe(prompt, num_inference_steps=2)

        # Only the two most recent checkpoint directories should be left;
        # checkpoint-2 should have been deleted.
        remaining_checkpoints = {entry for entry in os.listdir(tmpdir) if "checkpoint" in entry}
        expected_checkpoints = {"checkpoint-4", "checkpoint-6"}
        self.assertEqual(remaining_checkpoints, expected_checkpoints)
def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
    # Same scenario as the UNet-only SDXL LoRA checkpoint-limit test, but the
    # training script is additionally run with --train_text_encoder.
    pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
    prompt = "a prompt"

    with tempfile.TemporaryDirectory() as tmpdir:
        # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
        # -> checkpoints are written at steps 2, 4 and 6, and the one from
        #    step 2 is expected to be removed to honour the limit.
        training_args = f"""
            examples/text_to_image/train_text_to_image_lora_sdxl.py
            --pretrained_model_name_or_path {pipeline_path}
            --dataset_name hf-internal-testing/dummy_image_text_data
            --resolution 64
            --train_batch_size 1
            --gradient_accumulation_steps 1
            --max_train_steps 7
            --learning_rate 5.0e-04
            --scale_lr
            --lr_scheduler constant
            --train_text_encoder
            --lr_warmup_steps 0
            --output_dir {tmpdir}
            --checkpointing_steps=2
            --checkpoints_total_limit=2
            """.split()
        launch_command = self._launch_args + training_args
        run_command(launch_command)

        # The trained LoRA weights must be loadable and usable for inference.
        pipe = DiffusionPipeline.from_pretrained(pipeline_path)
        pipe.load_lora_weights(tmpdir)
        pipe(prompt, num_inference_steps=2)

        # Only the two most recent checkpoint directories should be left;
        # checkpoint-2 should have been deleted.
        remaining_checkpoints = {entry for entry in os.listdir(tmpdir) if "checkpoint" in entry}
        expected_checkpoints = {"checkpoint-4", "checkpoint-6"}
        self.assertEqual(remaining_checkpoints, expected_checkpoints)
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt" prompt = "a prompt"
......
...@@ -396,16 +396,6 @@ def parse_args(input_args=None): ...@@ -396,16 +396,6 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
), ),
) )
parser.add_argument(
"--prior_generation_precision",
type=str,
default=None,
choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument( parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
...@@ -724,11 +714,15 @@ def main(args): ...@@ -724,11 +714,15 @@ def main(args):
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder( LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
) )
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder( LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
) )
accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_save_state_pre_hook(save_model_hook)
...@@ -1002,9 +996,12 @@ def main(args): ...@@ -1002,9 +996,12 @@ def main(args):
continue continue
with accelerator.accumulate(unet): with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
# Convert images to latent space # Convert images to latent space
if args.pretrained_vae_model_name_or_path is not None:
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
else:
pixel_values = batch["pixel_values"]
model_input = vae.encode(pixel_values).latent_dist.sample() model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor model_input = model_input * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None: if args.pretrained_vae_model_name_or_path is None:
...@@ -1147,13 +1144,6 @@ def main(args): ...@@ -1147,13 +1144,6 @@ def main(args):
f" {args.validation_prompt}." f" {args.validation_prompt}."
) )
# create pipeline # create pipeline
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
pipeline = StableDiffusionXLPipeline.from_pretrained( pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
vae=vae, vae=vae,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment