Unverified commit caf4abf7, authored by Sylvain Gugger, committed by GitHub

Auto-resume training from checkpoint (#9776)



* Auto-resume training from checkpoint

* Update examples/text-classification/run_glue.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Roll out to other examples
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 0f443436
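The diffs below add the same checkpoint-detection block to every example script. As a minimal, self-contained sketch of that logic (the detect_last_checkpoint wrapper and its bare arguments are illustrative only; get_last_checkpoint is the helper this commit adds to transformers.trainer_utils):

import logging
import os

from transformers.trainer_utils import get_last_checkpoint

logger = logging.getLogger(__name__)


def detect_last_checkpoint(output_dir, do_train, overwrite_output_dir):
    # Sketch of the block each example script now runs before training.
    last_checkpoint = None
    if os.path.isdir(output_dir) and do_train and not overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(output_dir)
        if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
            raise ValueError(
                f"Output directory ({output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    return last_checkpoint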
@@ -42,7 +42,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import is_main_process
+from transformers.trainer_utils import get_last_checkpoint, is_main_process


 logger = logging.getLogger(__name__)

@@ -160,16 +160,20 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty."
-            "Use --overwrite_output_dir to overcome."
-        )
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )

     # Setup logging
     logging.basicConfig(

@@ -356,11 +360,12 @@ def main():
     # Training
     if training_args.do_train:
-        model_path = (
-            model_args.model_name_or_path
-            if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
-            else None
-        )
+        if last_checkpoint is not None:
+            model_path = last_checkpoint
+        elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
+            model_path = model_args.model_name_or_path
+        else:
+            model_path = None
         train_result = trainer.train(model_path=model_path)
         trainer.save_model()  # Saves the tokenizer too for easy upload
......
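In the training block, a detected checkpoint now takes precedence over a local model_name_or_path directory when choosing what to pass to trainer.train(model_path=...). A compact restatement of that resolution order (the resolve_model_path name is illustrative, not part of the diff):

import os
from typing import Optional


def resolve_model_path(last_checkpoint: Optional[str], model_name_or_path: Optional[str]) -> Optional[str]:
    # Resume checkpoint first, then a local model directory, otherwise start from scratch.
    if last_checkpoint is not None:
        return last_checkpoint
    if model_name_or_path is not None and os.path.isdir(model_name_or_path):
        return model_name_or_path
    return None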
@@ -42,7 +42,7 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import is_main_process
+from transformers.trainer_utils import get_last_checkpoint, is_main_process


 logger = logging.getLogger(__name__)

@@ -171,16 +171,20 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty."
-            "Use --overwrite_output_dir to overcome."
-        )
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )

     # Setup logging
     logging.basicConfig(

@@ -397,11 +401,12 @@ def main():
     # Training
     if training_args.do_train:
-        model_path = (
-            model_args.model_name_or_path
-            if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
-            else None
-        )
+        if last_checkpoint is not None:
+            model_path = last_checkpoint
+        elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
+            model_path = model_args.model_name_or_path
+        else:
+            model_path = None
         train_result = trainer.train(model_path=model_path)
         trainer.save_model()  # Saves the tokenizer too for easy upload
......
...@@ -44,7 +44,7 @@ from transformers import ( ...@@ -44,7 +44,7 @@ from transformers import (
TrainingArguments, TrainingArguments,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -184,16 +184,20 @@ def main(): ...@@ -184,16 +184,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -349,11 +353,12 @@ def main(): ...@@ -349,11 +353,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
model_path = ( if last_checkpoint is not None:
model_args.model_name_or_path model_path = last_checkpoint
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
else None model_path = model_args.model_name_or_path
) else:
model_path = None
train_result = trainer.train(model_path=model_path) train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
......
...@@ -38,7 +38,7 @@ from transformers import ( ...@@ -38,7 +38,7 @@ from transformers import (
XLNetLMHeadModel, XLNetLMHeadModel,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -168,16 +168,20 @@ def main(): ...@@ -168,16 +168,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -378,11 +382,12 @@ def main(): ...@@ -378,11 +382,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
model_path = ( if last_checkpoint is not None:
model_args.model_name_or_path model_path = last_checkpoint
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
else None model_path = model_args.model_name_or_path
) else:
model_path = None
train_result = trainer.train(model_path=model_path) train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
......
...@@ -39,7 +39,7 @@ from transformers import ( ...@@ -39,7 +39,7 @@ from transformers import (
set_seed, set_seed,
) )
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -194,16 +194,20 @@ def main(): ...@@ -194,16 +194,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -334,9 +338,13 @@ def main(): ...@@ -334,9 +338,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
......
...@@ -39,7 +39,7 @@ from transformers import ( ...@@ -39,7 +39,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
from utils_qa import postprocess_qa_predictions from utils_qa import postprocess_qa_predictions
...@@ -169,16 +169,20 @@ def main(): ...@@ -169,16 +169,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -453,9 +457,13 @@ def main(): ...@@ -453,9 +457,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
......
...@@ -38,7 +38,7 @@ from transformers import ( ...@@ -38,7 +38,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
from utils_qa import postprocess_qa_predictions_with_beam_search from utils_qa import postprocess_qa_predictions_with_beam_search
...@@ -168,16 +168,20 @@ def main(): ...@@ -168,16 +168,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -492,9 +496,13 @@ def main(): ...@@ -492,9 +496,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
......
...@@ -40,7 +40,7 @@ from transformers import ( ...@@ -40,7 +40,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -225,16 +225,20 @@ def main(): ...@@ -225,16 +225,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -481,9 +485,13 @@ def main(): ...@@ -481,9 +485,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
......
...@@ -38,7 +38,7 @@ from transformers import ( ...@@ -38,7 +38,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
task_to_keys = { task_to_keys = {
...@@ -160,16 +160,20 @@ def main(): ...@@ -160,16 +160,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty. " "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -385,9 +389,13 @@ def main(): ...@@ -385,9 +389,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
metrics = train_result.metrics metrics = train_result.metrics
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
......
...@@ -39,7 +39,7 @@ from transformers import ( ...@@ -39,7 +39,7 @@ from transformers import (
TrainingArguments, TrainingArguments,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -154,16 +154,20 @@ def main(): ...@@ -154,16 +154,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -374,9 +378,13 @@ def main(): ...@@ -374,9 +378,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
......
...@@ -17,7 +17,9 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc ...@@ -17,7 +17,9 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc
""" """
import copy import copy
import os
import random import random
import re
import time import time
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
...@@ -75,6 +77,15 @@ class TrainOutput(NamedTuple): ...@@ -75,6 +77,15 @@ class TrainOutput(NamedTuple):
PREFIX_CHECKPOINT_DIR = "checkpoint" PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d)+$")
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [path for path in content if _re_checkpoint.search(path) is not None and os.path.isdir(path)]
if len(checkpoints) == 0:
return
return max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]))
class EvaluationStrategy(ExplicitEnum): class EvaluationStrategy(ExplicitEnum):
......
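For reference, a standalone, working-directory-independent sketch of the same lookup; capturing the full step number and joining candidate names onto folder before the isdir check are assumptions beyond this diff, not part of the commit:

import os
import re

PREFIX_CHECKPOINT_DIR = "checkpoint"
# Capture the whole step number (e.g. "1500") rather than a single repeated digit.
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")


def get_last_checkpoint_sketch(folder):
    # Return the checkpoint-<step> subfolder with the highest step, or None if there is none.
    checkpoints = [
        path
        for path in os.listdir(folder)
        # Join onto `folder` so the isdir check does not depend on the current working directory.
        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return None
    return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))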
...@@ -39,7 +39,7 @@ from transformers import ( ...@@ -39,7 +39,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -168,16 +168,20 @@ def main(): ...@@ -168,16 +168,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
): raise ValueError(
raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty. "
f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome."
"Use --overwrite_output_dir to overcome." )
) elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -334,17 +338,21 @@ def main(): ...@@ -334,17 +338,21 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
{%- if cookiecutter.can_train_from_scratch == "False" %} {%- if cookiecutter.can_train_from_scratch == "False" %}
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
{%- elif cookiecutter.can_train_from_scratch == "True" %} {%- elif cookiecutter.can_train_from_scratch == "True" %}
model_path = ( if last_checkpoint is not None:
model_args.model_name_or_path model_path = last_checkpoint
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
else None model_path = model_args.model_name_or_path
) else:
train_result = trainer.train(model_path=model_path) model_path = None
{% endif %} {% endif %}
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
......