"git@developer.sourcefind.cn:change/sglang.git" did not exist on "c17c57810891591b3f7d5151d65b1e8d13af50f9"
Unverified Commit d49e2dd5 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

manual check for checkpoints_total_limit instead of using accelerate (#3681)

* manual check for checkpoints_total_limit instead of using accelerate

* remove controlnet_conditioning_embedding_out_channels
parent 7bfd2375
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
import math import math
import os import os
import random import random
import shutil
from pathlib import Path from pathlib import Path
import accelerate import accelerate
...@@ -307,11 +308,7 @@ def parse_args(input_args=None): ...@@ -307,11 +308,7 @@ def parse_args(input_args=None):
"--checkpoints_total_limit", "--checkpoints_total_limit",
type=int, type=int,
default=None, default=None,
help=( help=("Max number of checkpoints to store."),
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more details"
),
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", "--resume_from_checkpoint",
...@@ -716,9 +713,7 @@ def collate_fn(examples): ...@@ -716,9 +713,7 @@ def collate_fn(examples):
def main(args): def main(args):
logging_dir = Path(args.output_dir, args.logging_dir) logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
...@@ -1060,6 +1055,26 @@ def main(args): ...@@ -1060,6 +1055,26 @@ def main(args):
if accelerator.is_main_process: if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
......
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
import math import math
import os import os
import random import random
import shutil
import warnings import warnings
from pathlib import Path from pathlib import Path
...@@ -446,11 +447,7 @@ def parse_args(input_args=None): ...@@ -446,11 +447,7 @@ def parse_args(input_args=None):
"--checkpoints_total_limit", "--checkpoints_total_limit",
type=int, type=int,
default=None, default=None,
help=( help=("Max number of checkpoints to store."),
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", "--resume_from_checkpoint",
...@@ -637,9 +634,7 @@ def parse_args(input_args=None): ...@@ -637,9 +634,7 @@ def parse_args(input_args=None):
def main(args): def main(args):
logging_dir = Path(args.output_dir, args.logging_dir) logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
...@@ -1171,6 +1166,26 @@ def main(args): ...@@ -1171,6 +1166,26 @@ def main(args):
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process: if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
......
...@@ -20,6 +20,7 @@ import itertools ...@@ -20,6 +20,7 @@ import itertools
import logging import logging
import math import math
import os import os
import shutil
import warnings import warnings
from pathlib import Path from pathlib import Path
...@@ -771,9 +772,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte ...@@ -771,9 +772,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
def main(args): def main(args):
logging_dir = Path(args.output_dir, args.logging_dir) logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
...@@ -1270,12 +1269,33 @@ def main(args): ...@@ -1270,12 +1269,33 @@ def main(args):
global_step += 1 global_step += 1
if accelerator.is_main_process: if accelerator.is_main_process:
images = []
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
images = []
if args.validation_prompt is not None and global_step % args.validation_steps == 0: if args.validation_prompt is not None and global_step % args.validation_steps == 0:
images = log_validation( images = log_validation(
text_encoder, text_encoder,
......
...@@ -20,6 +20,7 @@ import itertools ...@@ -20,6 +20,7 @@ import itertools
import logging import logging
import math import math
import os import os
import shutil
import warnings import warnings
from pathlib import Path from pathlib import Path
...@@ -276,11 +277,7 @@ def parse_args(input_args=None): ...@@ -276,11 +277,7 @@ def parse_args(input_args=None):
"--checkpoints_total_limit", "--checkpoints_total_limit",
type=int, type=int,
default=None, default=None,
help=( help=("Max number of checkpoints to store."),
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", "--resume_from_checkpoint",
...@@ -653,9 +650,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte ...@@ -653,9 +650,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
def main(args): def main(args):
logging_dir = Path(args.output_dir, args.logging_dir) logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
...@@ -1221,6 +1216,26 @@ def main(args): ...@@ -1221,6 +1216,26 @@ def main(args):
if accelerator.is_main_process: if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
......
...@@ -20,6 +20,7 @@ import argparse ...@@ -20,6 +20,7 @@ import argparse
import logging import logging
import math import math
import os import os
import shutil
from pathlib import Path from pathlib import Path
import accelerate import accelerate
...@@ -327,11 +328,7 @@ def parse_args(): ...@@ -327,11 +328,7 @@ def parse_args():
"--checkpoints_total_limit", "--checkpoints_total_limit",
type=int, type=int,
default=None, default=None,
help=( help=("Max number of checkpoints to store."),
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", "--resume_from_checkpoint",
...@@ -387,9 +384,7 @@ def main(): ...@@ -387,9 +384,7 @@ def main():
), ),
) )
logging_dir = os.path.join(args.output_dir, args.logging_dir) logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, mixed_precision=args.mixed_precision,
...@@ -867,6 +862,26 @@ def main(): ...@@ -867,6 +862,26 @@ def main():
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process: if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
......
...@@ -435,8 +435,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -435,8 +435,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
pipe(prompt, num_inference_steps=2) pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist # check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) self.assertEqual(
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) {x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
# check can run an intermediate checkpoint # check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
...@@ -474,12 +476,15 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -474,12 +476,15 @@ class ExamplesTestsAccelerate(unittest.TestCase):
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2) pipe(prompt, num_inference_steps=2)
# check old checkpoints do not exist self.assertEqual(
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) {x for x in os.listdir(tmpdir) if "checkpoint" in x},
{
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist # check new checkpoints exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) "checkpoint-4",
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6"))) "checkpoint-6",
},
)
def test_text_to_image_checkpointing_use_ema(self): def test_text_to_image_checkpointing_use_ema(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
...@@ -516,8 +521,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -516,8 +521,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
pipe(prompt, num_inference_steps=2) pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist # check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) self.assertEqual(
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) {x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
# check can run an intermediate checkpoint # check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet") unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
...@@ -556,9 +563,773 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -556,9 +563,773 @@ class ExamplesTestsAccelerate(unittest.TestCase):
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None) pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2) pipe(prompt, num_inference_steps=2)
# check old checkpoints do not exist self.assertEqual(
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2"))) {x for x in os.listdir(tmpdir) if "checkpoint" in x},
{
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist # check new checkpoints exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4"))) "checkpoint-4",
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6"))) "checkpoint-6",
},
)
def test_text_to_image_checkpointing_checkpoints_total_limit(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 9, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4, 6, 8
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 9
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
# resume and we should try to checkpoint at 10, where we'll have to remove
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
resume_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 11
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
--seed=0
--num_validation_images=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 9, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4, 6, 8
initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 9
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
--num_validation_images=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
# resume and we should try to checkpoint at 10, where we'll have to remove
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
resume_run_args = f"""
examples/text_to_image/train_text_to_image_lora.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 11
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--seed=0
--num_validation_images=0
""".split()
run_command(self._launch_args + resume_run_args)
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
def test_unconditional_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
initial_run_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name hf-internal-testing/dummy_image_class_data
--model_config_name_or_path diffusers/ddpm_dummy
--resolution 64
--output_dir {tmpdir}
--train_batch_size 1
--num_epochs 1
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--learning_rate 1e-3
--lr_warmup_steps 5
--checkpointing_steps=2
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + initial_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
initial_run_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name hf-internal-testing/dummy_image_class_data
--model_config_name_or_path diffusers/ddpm_dummy
--resolution 64
--output_dir {tmpdir}
--train_batch_size 1
--num_epochs 1
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--learning_rate 1e-3
--lr_warmup_steps 5
--checkpointing_steps=1
""".split()
run_command(self._launch_args + initial_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
)
resume_run_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name hf-internal-testing/dummy_image_class_data
--model_config_name_or_path diffusers/ddpm_dummy
--resolution 64
--output_dir {tmpdir}
--train_batch_size 1
--num_epochs 2
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--learning_rate 1e-3
--lr_warmup_steps 5
--resume_from_checkpoint=checkpoint-6
--checkpointing_steps=2
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
)
def test_textual_inversion_checkpointing(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--train_data_dir docs/source/en/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 3
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=1
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + test_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-3"},
)
def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--train_data_dir docs/source/en/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 3
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=1
""".split()
run_command(self._launch_args + test_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-1", "checkpoint-2", "checkpoint-3"},
)
resume_run_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--train_data_dir docs/source/en/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=1
--resume_from_checkpoint=checkpoint-3
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-3", "checkpoint-4"},
)
def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/instruct_pix2pix/train_instruct_pix2pix.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=7
--checkpointing_steps=2
--checkpoints_total_limit=2
--output_dir {tmpdir}
--seed=0
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/instruct_pix2pix/train_instruct_pix2pix.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=9
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
""".split()
run_command(self._launch_args + test_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/instruct_pix2pix/train_instruct_pix2pix.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=11
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
def test_controlnet_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/controlnet/train_controlnet.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/controlnet/train_controlnet.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/controlnet/train_controlnet.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=11
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
)
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/custom_diffusion/train_custom_diffusion.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=<new1>
--resolution=64
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/custom_diffusion/train_custom_diffusion.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=<new1>
--resolution=64
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/custom_diffusion/train_custom_diffusion.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=<new1>
--resolution=64
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=11
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
import math import math
import os import os
import random import random
import shutil
from pathlib import Path from pathlib import Path
import accelerate import accelerate
...@@ -362,11 +363,7 @@ def parse_args(): ...@@ -362,11 +363,7 @@ def parse_args():
"--checkpoints_total_limit", "--checkpoints_total_limit",
type=int, type=int,
default=None, default=None,
help=( help=("Max number of checkpoints to store."),
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", "--resume_from_checkpoint",
...@@ -427,9 +424,7 @@ def main(): ...@@ -427,9 +424,7 @@ def main():
) )
logging_dir = os.path.join(args.output_dir, args.logging_dir) logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
...@@ -909,6 +904,26 @@ def main(): ...@@ -909,6 +904,26 @@ def main():
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process: if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
......
...@@ -19,6 +19,7 @@ import logging ...@@ -19,6 +19,7 @@ import logging
import math import math
import os import os
import random import random
import shutil
from pathlib import Path from pathlib import Path
import datasets import datasets
...@@ -327,11 +328,7 @@ def parse_args(): ...@@ -327,11 +328,7 @@ def parse_args():
"--checkpoints_total_limit", "--checkpoints_total_limit",
type=int, type=int,
default=None, default=None,
help=( help=("Max number of checkpoints to store."),
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", "--resume_from_checkpoint",
...@@ -368,9 +365,7 @@ def main(): ...@@ -368,9 +365,7 @@ def main():
args = parse_args() args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir) logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
...@@ -809,6 +804,26 @@ def main(): ...@@ -809,6 +804,26 @@ def main():
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process: if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
...@@ -903,6 +918,7 @@ def main(): ...@@ -903,6 +918,7 @@ def main():
if accelerator.is_main_process: if accelerator.is_main_process:
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if len(images) != 0:
if tracker.name == "tensorboard": if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images]) np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
......
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
import math import math
import os import os
import random import random
import shutil
import warnings import warnings
from pathlib import Path from pathlib import Path
...@@ -394,11 +395,7 @@ def parse_args(): ...@@ -394,11 +395,7 @@ def parse_args():
"--checkpoints_total_limit", "--checkpoints_total_limit",
type=int, type=int,
default=None, default=None,
help=( help=("Max number of checkpoints to store."),
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", "--resume_from_checkpoint",
...@@ -566,9 +563,7 @@ class TextualInversionDataset(Dataset): ...@@ -566,9 +563,7 @@ class TextualInversionDataset(Dataset):
def main(): def main():
args = parse_args() args = parse_args()
logging_dir = os.path.join(args.output_dir, args.logging_dir) logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, mixed_precision=args.mixed_precision,
...@@ -887,6 +882,26 @@ def main(): ...@@ -887,6 +882,26 @@ def main():
if accelerator.is_main_process: if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
......
...@@ -3,6 +3,7 @@ import inspect ...@@ -3,6 +3,7 @@ import inspect
import logging import logging
import math import math
import os import os
import shutil
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
...@@ -245,11 +246,7 @@ def parse_args(): ...@@ -245,11 +246,7 @@ def parse_args():
"--checkpoints_total_limit", "--checkpoints_total_limit",
type=int, type=int,
default=None, default=None,
help=( help=("Max number of checkpoints to store."),
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
) )
parser.add_argument( parser.add_argument(
"--resume_from_checkpoint", "--resume_from_checkpoint",
...@@ -287,9 +284,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: ...@@ -287,9 +284,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
def main(args): def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir) logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration( accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
...@@ -607,6 +602,26 @@ def main(args): ...@@ -607,6 +602,26 @@ def main(args):
global_step += 1 global_step += 1
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
if accelerator.is_main_process: if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment