Commit 56d4ba8d authored by Julien Chaumond

[run_lm_finetuning] Train from scratch

parent c7f79815
@@ -28,7 +28,7 @@ import pickle
 import random
 import re
 import shutil
-from typing import Tuple
+from typing import Dict, List, Tuple
 import numpy as np
 import torch
@@ -54,6 +54,7 @@ from transformers import (
     OpenAIGPTConfig,
     OpenAIGPTLMHeadModel,
     OpenAIGPTTokenizer,
+    PreTrainedModel,
     PreTrainedTokenizer,
     RobertaConfig,
     RobertaForMaskedLM,
@@ -82,11 +83,11 @@ MODEL_CLASSES = {
 class TextDataset(Dataset):
-    def __init__(self, tokenizer, args, file_path="train", block_size=512):
+    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path="train", block_size=512):
         assert os.path.isfile(file_path)
         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
-            directory, args.model_name_or_path + "_cached_lm_" + str(block_size) + "_" + filename
+            directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
         )
         if os.path.exists(cached_features_file) and not args.overwrite_cache:
@@ -120,13 +121,12 @@ class TextDataset(Dataset):
 def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = TextDataset(
+    return TextDataset(
         tokenizer,
         args,
         file_path=args.eval_data_file if evaluate else args.train_data_file,
         block_size=args.block_size,
     )
-    return dataset
 def set_seed(args):
@@ -137,18 +137,11 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
-def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
-    if not args.save_total_limit:
-        return
-    if args.save_total_limit <= 0:
-        return
-    # Check if we should delete older checkpoint(s)
+def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
+    ordering_and_checkpoint_path = []
     glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
-    if len(glob_checkpoints) <= args.save_total_limit:
-        return
-    ordering_and_checkpoint_path = []
     for path in glob_checkpoints:
         if use_mtime:
             ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
@@ -159,6 +152,20 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
     checkpoints_sorted = sorted(ordering_and_checkpoint_path)
     checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+    return checkpoints_sorted
+def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
+    if not args.save_total_limit:
+        return
+    if args.save_total_limit <= 0:
+        return
+    # Check if we should delete older checkpoint(s)
+    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
+    if len(checkpoints_sorted) <= args.save_total_limit:
+        return
     number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
     checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
     for checkpoint in checkpoints_to_be_deleted:
@@ -191,7 +198,7 @@ def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> T
     return inputs, labels
-def train(args, train_dataset, model, tokenizer):
+def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -221,7 +228,7 @@ def train(args, train_dataset, model, tokenizer):
     )
     # Check if saved optimizer or scheduler states exist
-    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
+    if args.model_name_or_path and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
         os.path.join(args.model_name_or_path, "scheduler.pt")
     ):
         # Load in optimizer and scheduler states
@@ -263,7 +270,7 @@ def train(args, train_dataset, model, tokenizer):
     epochs_trained = 0
     steps_trained_in_current_epoch = 0
     # Check if continuing training from a checkpoint
-    if os.path.exists(args.model_name_or_path):
+    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
         try:
             # set global_step to gobal_step of last saved checkpoint from model path
             checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
@@ -342,8 +349,7 @@ def train(args, train_dataset, model, tokenizer):
                     checkpoint_prefix = "checkpoint"
                     # Save model checkpoint
                     output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
-                    if not os.path.exists(output_dir):
-                        os.makedirs(output_dir)
+                    os.makedirs(output_dir, exist_ok=True)
                     model_to_save = (
                         model.module if hasattr(model, "module") else model
                     )  # Take care of distributed/parallel training
@@ -372,14 +378,14 @@ def train(args, train_dataset, model, tokenizer):
     return global_step, tr_loss / global_step
-def evaluate(args, model, tokenizer, prefix=""):
+def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
     # Loop to handle MNLI double evaluation (matched, mis-matched)
     eval_output_dir = args.output_dir
     eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
-    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
-        os.makedirs(eval_output_dir)
+    if args.local_rank in [-1, 0]:
+        os.makedirs(eval_output_dir, exist_ok=True)
     args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
     # Note that DistributedSampler samples randomly
@@ -433,11 +439,16 @@ def main():
     )
     parser.add_argument(
         "--output_dir",
-        default=None,
         type=str,
         required=True,
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
+    )
+    parser.add_argument(
+        "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
+    )
     # Other parameters
     parser.add_argument(
@@ -447,12 +458,11 @@ def main():
         help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
     )
-    parser.add_argument("--model_type", default="bert", type=str, help="The model architecture to be fine-tuned.")
     parser.add_argument(
         "--model_name_or_path",
-        default="bert-base-cased",
+        default=None,
         type=str,
-        help="The model checkpoint for weights initialization.",
+        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
     )
     parser.add_argument(
@@ -464,19 +474,25 @@ def main():
     parser.add_argument(
         "--config_name",
-        default="",
+        default=None,
         type=str,
-        help="Optional pretrained config name or path if not the same as model_name_or_path",
+        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
     )
     parser.add_argument(
         "--tokenizer_name",
+        default=None,
+        type=str,
+        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
+    )
+    parser.add_argument(
+        "--tokenizer_init_args",
         default="",
         type=str,
-        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path",
+        help="If instantiating a new tokenizer, comma-separated list of input args to feed the constructor.",
     )
     parser.add_argument(
         "--cache_dir",
-        default="",
+        default=None,
         type=str,
         help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
     )
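Judging from the instantiation code further down in this diff, the new --tokenizer_init_args string is split on commas and passed positionally to the tokenizer constructor whenever neither --tokenizer_name nor --model_name_or_path is given. A minimal sketch of that path (illustrative only, not code from this commit; assumes a GPT-2-style tokenizer and hypothetical vocab/merges paths):

from transformers import GPT2Tokenizer

# e.g. the script is invoked with --tokenizer_init_args vocab.json,merges.txt
tokenizer_init_args = "vocab.json,merges.txt"  # hypothetical CLI value
tokenizer = GPT2Tokenizer(*tokenizer_init_args.split(","))  # GPT2Tokenizer("vocab.json", "merges.txt")

For a single-file tokenizer such as BERT's, the comma-separated list would contain just the vocabulary file path.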
@@ -493,9 +509,6 @@ def main():
     parser.add_argument(
         "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
     )
-    parser.add_argument(
-        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
-    )
     parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
     parser.add_argument(
@@ -563,7 +576,7 @@ def main():
     if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
         raise ValueError(
-            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
+            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling)."
         )
     if args.eval_data_file is None and args.do_eval:
@@ -571,6 +584,14 @@ def main():
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
+    if args.should_continue:
+        sorted_checkpoints = _sorted_checkpoints(args)
+        if len(sorted_checkpoints) == 0:
+            raise ValueError(
+                "Used --should_continue but no checkpoint was found in --output_dir."
+            )
+        else:
+            args.model_name_or_path = sorted_checkpoints[-1]
     if (
         os.path.exists(args.output_dir)
@@ -627,26 +648,42 @@ def main():
         torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
-        args.config_name if args.config_name else args.model_name_or_path,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
-    tokenizer = tokenizer_class.from_pretrained(
-        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
-        do_lower_case=args.do_lower_case,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
+    if args.config_name:
+        config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
+    elif args.model_name_or_path:
+        config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+    else:
+        config = config_class()
+    if args.tokenizer_name:
+        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
+    elif args.model_name_or_path:
+        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+    else:
+        logger.warning(
+            "You are instantiating a new {} tokenizer from scratch. Are you sure this is what you meant to do?"
+            "To specifiy a pretrained tokenizer name, use --tokenizer_name".format(tokenizer_class.__name__)
+        )
+        tokenizer = tokenizer_class(*args.tokenizer_init_args.split(","))
     if args.block_size <= 0:
-        args.block_size = (
-            tokenizer.max_len_single_sentence
-        )  # Our input block size will be the max possible for the model
-    args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
+        args.block_size = tokenizer.max_len_single_sentence
+        # Our input block size will be the max possible for the model
+    else:
+        args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
-    model = model_class.from_pretrained(
-        args.model_name_or_path,
-        from_tf=bool(".ckpt" in args.model_name_or_path),
-        config=config,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
+    if args.model_name_or_path:
+        model = model_class.from_pretrained(
+            args.model_name_or_path,
+            from_tf=bool(".ckpt" in args.model_name_or_path),
+            config=config,
+            cache_dir=args.cache_dir,
+        )
+    else:
+        logger.info("Training new model from scratch")
+        model = model_class(config=config)
     model.to(args.device)
     if args.local_rank == 0:
@@ -670,8 +707,8 @@ def main():
     # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
-        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
-            os.makedirs(args.output_dir)
+        if args.local_rank in [-1, 0]:
+            os.makedirs(args.output_dir, exist_ok=True)
         logger.info("Saving model checkpoint to %s", args.output_dir)
         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
@@ -687,7 +724,7 @@ def main():
         # Load a trained model and vocabulary that you have fine-tuned
         model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
         model.to(args.device)
     # Evaluation
...
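Taken together, these changes let the script build the config, tokenizer and model without any pretrained checkpoint when --model_name_or_path is left unset. A minimal sketch of that from-scratch path (illustrative only, not code from this commit; assumes GPT-2 classes, library-default hyperparameters and hypothetical tokenizer files):

from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

# Fresh config with library defaults, i.e. the `config = config_class()` branch above.
config = GPT2Config()

# Tokenizer built from local files rather than a pretrained name
# (mirrors `tokenizer_class(*args.tokenizer_init_args.split(","))`); paths are hypothetical.
tokenizer = GPT2Tokenizer("vocab.json", "merges.txt")

# Randomly initialized weights, i.e. the `model = model_class(config=config)` branch above.
model = GPT2LMHeadModel(config=config)

A later run can then pass --should_continue, which, per the logic added above, picks the most recent checkpoint-* directory in --output_dir and sets it as --model_name_or_path to resume training.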