Unverified Commit d49e2dd5 authored by Will Berman, committed by GitHub

manual check for checkpoints_total_limit instead of using accelerate (#3681)

* manual check for checkpoints_total_limit instead of using accelerate

* remove controlnet_conditioning_embedding_out_channels
parent 7bfd2375
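
Every training script touched below gains the same pruning block, so it helps to read the logic once in isolation. A minimal standalone sketch (the helper name prune_checkpoints and its signature are illustrative, not part of this commit):

import os
import shutil


def prune_checkpoints(output_dir, total_limit):
    if total_limit is None:
        return

    # Existing checkpoint directories, sorted oldest-first by step number.
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

    # Leave room for the checkpoint about to be written: keep at most total_limit - 1.
    if len(checkpoints) >= total_limit:
        num_to_remove = len(checkpoints) - total_limit + 1
        for name in checkpoints[:num_to_remove]:
            shutil.rmtree(os.path.join(output_dir, name))

Each script inlines this check immediately before accelerator.save_state, instead of passing total_limit to accelerate's ProjectConfiguration.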
@@ -18,6 +18,7 @@ import logging
 import math
 import os
 import random
+import shutil
 from pathlib import Path

 import accelerate
@@ -307,11 +308,7 @@ def parse_args(input_args=None):
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
-            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
-            " for more details"
-        ),
+        help=("Max number of checkpoints to store."),
     )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -716,9 +713,7 @@ def collate_fn(examples):
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -1060,6 +1055,26 @@ def main(args):
                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
...
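Two details of the block above are easy to miss. The numeric sort key matters because a plain lexicographic sort would order "checkpoint-1000" before "checkpoint-500"; parsing the step with int(x.split("-")[1]) keeps the list oldest-first. And the arithmetic deliberately stops one short of the limit: with --checkpoints_total_limit 3 and checkpoint-500, checkpoint-1000, checkpoint-1500 already on disk, num_to_remove = 3 - 3 + 1 = 1, so checkpoint-500 is deleted and the directory again holds exactly three checkpoints once checkpoint-2000 is saved.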
@@ -21,6 +21,7 @@ import logging
 import math
 import os
 import random
+import shutil
 import warnings
 from pathlib import Path
@@ -446,11 +447,7 @@ def parse_args(input_args=None):
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
-            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
-            " for more docs"
-        ),
+        help=("Max number of checkpoints to store."),
     )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -637,9 +634,7 @@ def parse_args(input_args=None):
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -1171,6 +1166,26 @@ def main(args):
                 if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
...
@@ -20,6 +20,7 @@ import itertools
 import logging
 import math
 import os
+import shutil
 import warnings
 from pathlib import Path
@@ -771,9 +772,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -1270,12 +1269,33 @@ def main(args):
                 global_step += 1

                 if accelerator.is_main_process:
-                    images = []
                     if global_step % args.checkpointing_steps == 0:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")

+                    images = []
+
                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                         images = log_validation(
                             text_encoder,
...
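Note that in this script the images = [] reset also moves below the checkpointing block, so the list is still emptied right before the validation step that may repopulate it, and the new pruning logic no longer sits between the reset and its use.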
@@ -20,6 +20,7 @@ import itertools
 import logging
 import math
 import os
+import shutil
 import warnings
 from pathlib import Path
@@ -276,11 +277,7 @@ def parse_args(input_args=None):
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
-            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
-            " for more docs"
-        ),
+        help=("Max number of checkpoints to store."),
     )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -653,9 +650,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -1221,6 +1216,26 @@ def main(args):
                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
...
@@ -20,6 +20,7 @@ import argparse
 import logging
 import math
 import os
+import shutil
 from pathlib import Path

 import accelerate
@@ -327,11 +328,7 @@ def parse_args():
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
-            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
-            " for more docs"
-        ),
+        help=("Max number of checkpoints to store."),
     )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -387,9 +384,7 @@ def main():
         ),
     )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -867,6 +862,26 @@ def main():
                 if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
...
@@ -18,6 +18,7 @@ import logging
 import math
 import os
 import random
+import shutil
 from pathlib import Path

 import accelerate
@@ -362,11 +363,7 @@ def parse_args():
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
-            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
-            " for more docs"
-        ),
+        help=("Max number of checkpoints to store."),
     )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -427,9 +424,7 @@ def main():
     )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)

-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -909,6 +904,26 @@ def main():
                 if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
...
@@ -19,6 +19,7 @@ import logging
 import math
 import os
 import random
+import shutil
 from pathlib import Path

 import datasets
@@ -327,11 +328,7 @@ def parse_args():
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
-            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
-            " for more docs"
-        ),
+        help=("Max number of checkpoints to store."),
     )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -368,9 +365,7 @@ def main():
     args = parse_args()

     logging_dir = Path(args.output_dir, args.logging_dir)
-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -809,6 +804,26 @@ def main():
                 if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
@@ -903,18 +918,19 @@ def main():
         if accelerator.is_main_process:
             for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "test": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
+                if len(images) != 0:
+                    if tracker.name == "tensorboard":
+                        np_images = np.stack([np.asarray(img) for img in images])
+                        tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+                    if tracker.name == "wandb":
+                        tracker.log(
+                            {
+                                "test": [
+                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                                    for i, image in enumerate(images)
+                                ]
+                            }
+                        )

     accelerator.end_training()
...
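The len(images) != 0 guard added in the hunk above avoids logging in epochs where no validation ran: np.stack raises a ValueError on an empty list, and the wandb call would otherwise log an empty image list.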
@@ -18,6 +18,7 @@ import logging
 import math
 import os
 import random
+import shutil
 import warnings
 from pathlib import Path
@@ -394,11 +395,7 @@ def parse_args():
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
-            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
-            " for more docs"
-        ),
+        help=("Max number of checkpoints to store."),
    )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -566,9 +563,7 @@ class TextualInversionDataset(Dataset):
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -887,6 +882,26 @@ def main():
                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
...
@@ -3,6 +3,7 @@ import inspect
 import logging
 import math
 import os
+import shutil
 from pathlib import Path
 from typing import Optional
@@ -245,11 +246,7 @@ def parse_args():
         "--checkpoints_total_limit",
         type=int,
         default=None,
-        help=(
-            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
-            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
-            " for more docs"
-        ),
+        help=("Max number of checkpoints to store."),
     )
     parser.add_argument(
         "--resume_from_checkpoint",
@@ -287,9 +284,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)

-    accelerator_project_config = ProjectConfiguration(
-        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
-    )
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -607,6 +602,26 @@ def main(args):
                 global_step += 1

                 if global_step % args.checkpointing_steps == 0:
+                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                    if args.checkpoints_total_limit is not None:
+                        checkpoints = os.listdir(args.output_dir)
+                        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                        if len(checkpoints) >= args.checkpoints_total_limit:
+                            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                            removing_checkpoints = checkpoints[0:num_to_remove]
+
+                            logger.info(
+                                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                            )
+                            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                            for removing_checkpoint in removing_checkpoints:
+                                removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                shutil.rmtree(removing_checkpoint)
+
                     if accelerator.is_main_process:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
...
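Unlike the other scripts in this commit, the pruning block in this last file sits under the checkpointing_steps check but outside the accelerator.is_main_process guard, which here wraps only the save itself.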