"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1e0395e7918f54a53ce14f14279ad07913038bcf"
Unverified Commit 296b01e1 authored by Will Berman, committed by GitHub

add total number checkpoints to training scripts (#2367)



* add total number checkpoints to training scripts

* Update examples/dreambooth/train_dreambooth.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent a3ae4661
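Every script touched below follows the same pattern: a new `--checkpointing_steps_total_limit` flag is parsed, forwarded to `accelerate.utils.ProjectConfiguration` as `total_limit`, and the resulting config is handed to the `Accelerator`. A minimal, self-contained sketch of that wiring (only the public `accelerate` API is assumed; everything else here is illustrative, not taken from the scripts):

import argparse

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpointing_steps_total_limit",
    type=int,
    default=None,
    help="Max number of checkpoints to keep; forwarded to ProjectConfiguration as total_limit.",
)
args = parser.parse_args()

# None keeps every checkpoint; an integer caps how many saved states are retained.
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
accelerator = Accelerator(project_config=accelerator_project_config)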
@@ -30,7 +30,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from packaging import version
 from PIL import Image
@@ -195,6 +195,16 @@ def parse_args(input_args=None):
             "instructions."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more details"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -488,11 +498,14 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
...
@@ -29,7 +29,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
 from torch.utils.data import Dataset
@@ -242,6 +242,16 @@ def parse_args(input_args=None):
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -526,11 +536,14 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     if args.report_to == "wandb":
...
@@ -13,7 +13,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image, ImageDraw
 from torch.utils.data import Dataset
@@ -258,6 +258,16 @@ def parse_args():
             " using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -406,11 +416,14 @@ def main():
     args = parse_args()
     logging_dir = Path(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
...
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image, ImageDraw
 from torch.utils.data import Dataset
@@ -254,6 +254,16 @@ def parse_args():
             " using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -405,11 +415,14 @@ def main():
     args = parse_args()
     logging_dir = Path(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
...
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
 from torch.utils.data import Dataset
@@ -170,6 +170,16 @@ def parse_args(input_args=None):
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -466,11 +476,14 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
...
@@ -29,7 +29,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from onnxruntime.training.ortmodule import ORTModule
@@ -274,6 +274,16 @@ def parse_args():
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -322,11 +332,14 @@ def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     # Make one log on every process with the configuration for debugging.
...
@@ -30,7 +30,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from onnxruntime.training.ortmodule import ORTModule
@@ -290,6 +290,16 @@ def parse_args():
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -467,11 +477,14 @@ def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     if args.report_to == "wandb":
...
@@ -11,6 +11,7 @@ import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from onnxruntime.training.ortmodule import ORTModule
@@ -231,6 +232,16 @@ def parse_args():
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -265,11 +276,14 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.logger,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     if args.logger == "tensorboard":
...
@@ -30,7 +30,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from packaging import version
@@ -275,6 +275,16 @@ def parse_args():
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -333,11 +343,14 @@ def main():
         )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     # Make one log on every process with the configuration for debugging.
...
@@ -30,7 +30,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from torchvision import transforms
@@ -310,6 +310,16 @@ def parse_args():
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -354,11 +364,14 @@ def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     if args.report_to == "wandb":
         if not is_wandb_available():
...
@@ -29,7 +29,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -288,6 +288,16 @@ def parse_args():
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -465,11 +475,14 @@ def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )
     if args.report_to == "wandb":
...
@@ -12,6 +12,7 @@ import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from packaging import version
@@ -239,6 +240,16 @@ def parse_args():
             " training using `--resume_from_checkpoint`."
         ),
     )
+    parser.add_argument(
+        "--checkpointing_steps_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -273,11 +284,14 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpointing_steps_total_limit)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.logger,
         logging_dir=logging_dir,
+        project_config=accelerator_project_config,
    )
     if args.logger == "tensorboard":
...
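For reference, `total_limit` takes effect when `Accelerator.save_state` manages checkpoint folders itself, i.e. when automatic checkpoint naming is enabled: once more than `total_limit` checkpoints exist, the oldest ones are deleted. A hedged sketch of that behaviour (the `out` directory name and the bare loop are illustrative, not taken from these scripts):

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# Keep only the two most recent checkpoints under out/checkpoints/.
project_config = ProjectConfiguration(
    project_dir="out",
    automatic_checkpoint_naming=True,
    total_limit=2,
)
accelerator = Accelerator(project_config=project_config)

for _ in range(5):
    # Each call writes a new checkpoint_<n> folder; folders beyond total_limit are pruned, oldest first.
    accelerator.save_state()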