Commit 56d4ba8d authored by Julien Chaumond

[run_lm_finetuning] Train from scratch

parent c7f79815
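
Note: this commit lets run_lm_finetuning.py train a model from scratch rather than only fine-tune one: --model_name_or_path becomes optional, and a fresh config, tokenizer and model are instantiated when it is omitted. A hypothetical invocation of the new path (all paths are placeholders, not defaults from the script):

    python run_lm_finetuning.py --model_type roberta --config_name ./my_config --tokenizer_name ./my_tokenizer --train_data_file ./train.txt --output_dir ./output --do_train --mlm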
@@ -28,7 +28,7 @@ import pickle
 import random
 import re
 import shutil
-from typing import Tuple
+from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -54,6 +54,7 @@ from transformers import (
     OpenAIGPTConfig,
     OpenAIGPTLMHeadModel,
     OpenAIGPTTokenizer,
+    PreTrainedModel,
     PreTrainedTokenizer,
     RobertaConfig,
     RobertaForMaskedLM,
@@ -82,11 +83,11 @@ MODEL_CLASSES = {
 
 
 class TextDataset(Dataset):
-    def __init__(self, tokenizer, args, file_path="train", block_size=512):
+    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path="train", block_size=512):
         assert os.path.isfile(file_path)
 
         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
-            directory, args.model_name_or_path + "_cached_lm_" + str(block_size) + "_" + filename
+            directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
         )
 
         if os.path.exists(cached_features_file) and not args.overwrite_cache:
@@ -120,13 +121,12 @@ class TextDataset(Dataset):
 
 
 def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = TextDataset(
+    return TextDataset(
         tokenizer,
         args,
         file_path=args.eval_data_file if evaluate else args.train_data_file,
         block_size=args.block_size,
     )
-    return dataset
 
 
 def set_seed(args):
@@ -137,18 +137,11 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
-def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
-    if not args.save_total_limit:
-        return
-    if args.save_total_limit <= 0:
-        return
+def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
+    ordering_and_checkpoint_path = []
 
-    # Check if we should delete older checkpoint(s)
     glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
-    if len(glob_checkpoints) <= args.save_total_limit:
-        return
 
-    ordering_and_checkpoint_path = []
     for path in glob_checkpoints:
         if use_mtime:
             ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
@@ -159,6 +152,20 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
 
     checkpoints_sorted = sorted(ordering_and_checkpoint_path)
     checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+    return checkpoints_sorted
+
+
+def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
+    if not args.save_total_limit:
+        return
+    if args.save_total_limit <= 0:
+        return
+
+    # Check if we should delete older checkpoint(s)
+    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
+    if len(checkpoints_sorted) <= args.save_total_limit:
+        return
+
     number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
     checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
     for checkpoint in checkpoints_to_be_deleted:
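
Note: checkpoints are saved as directories named checkpoint-<global_step> (see checkpoint_prefix below), so _sorted_checkpoints must order them numerically for checkpoints_sorted[-1] to be the latest one. The non-mtime branch of the ordering falls outside this excerpt, so the following is only a sketch of the numeric ordering it implies, with hypothetical paths:

    import re

    paths = ["out/checkpoint-900", "out/checkpoint-10500", "out/checkpoint-2000"]

    def step_of(path):
        # Pull the trailing global step out of "checkpoint-<step>".
        return int(re.match(r".*checkpoint-(\d+)", path).group(1))

    # A plain lexicographic sort would wrongly put checkpoint-10500 before checkpoint-2000.
    print(sorted(paths, key=step_of)[-1])  # out/checkpoint-10500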
@@ -191,7 +198,7 @@ def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
     return inputs, labels
 
 
-def train(args, train_dataset, model, tokenizer):
+def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -221,7 +228,7 @@ def train(args, train_dataset, model, tokenizer):
     )
 
     # Check if saved optimizer or scheduler states exist
-    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
+    if args.model_name_or_path and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
         os.path.join(args.model_name_or_path, "scheduler.pt")
     ):
         # Load in optimizer and scheduler states
@@ -263,7 +270,7 @@ def train(args, train_dataset, model, tokenizer):
     epochs_trained = 0
     steps_trained_in_current_epoch = 0
     # Check if continuing training from a checkpoint
-    if os.path.exists(args.model_name_or_path):
+    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
         try:
             # set global_step to global_step of last saved checkpoint from model path
             checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
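
Note: the resume path recovers the global step from the checkpoint directory name itself. A worked example of that parsing, with a hypothetical path:

    # e.g. args.model_name_or_path = "./output/checkpoint-1000"
    checkpoint_suffix = "./output/checkpoint-1000".split("-")[-1].split("/")[0]
    global_step = int(checkpoint_suffix)
    print(global_step)  # 1000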
@@ -342,8 +349,7 @@ def train(args, train_dataset, model, tokenizer):
                     checkpoint_prefix = "checkpoint"
                     # Save model checkpoint
                     output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
-                    if not os.path.exists(output_dir):
-                        os.makedirs(output_dir)
+                    os.makedirs(output_dir, exist_ok=True)
                     model_to_save = (
                         model.module if hasattr(model, "module") else model
                     )  # Take care of distributed/parallel training
@@ -372,14 +378,14 @@ def train(args, train_dataset, model, tokenizer):
     return global_step, tr_loss / global_step
 
 
-def evaluate(args, model, tokenizer, prefix=""):
+def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
     # Loop to handle MNLI double evaluation (matched, mis-matched)
     eval_output_dir = args.output_dir
 
     eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
-    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
-        os.makedirs(eval_output_dir)
+    if args.local_rank in [-1, 0]:
+        os.makedirs(eval_output_dir, exist_ok=True)
 
     args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
     # Note that DistributedSampler samples randomly
@@ -433,11 +439,16 @@ def main():
     )
     parser.add_argument(
         "--output_dir",
-        default=None,
         type=str,
         required=True,
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
+    )
+    parser.add_argument(
+        "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
+    )
 
     # Other parameters
     parser.add_argument(
@@ -447,12 +458,11 @@ def main():
         help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
     )
 
-    parser.add_argument("--model_type", default="bert", type=str, help="The model architecture to be fine-tuned.")
     parser.add_argument(
         "--model_name_or_path",
-        default="bert-base-cased",
+        default=None,
         type=str,
-        help="The model checkpoint for weights initialization.",
+        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
     )
     parser.add_argument(
@@ -464,19 +474,25 @@ def main():
     parser.add_argument(
         "--config_name",
-        default="",
+        default=None,
         type=str,
-        help="Optional pretrained config name or path if not the same as model_name_or_path",
+        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
     )
     parser.add_argument(
         "--tokenizer_name",
+        default=None,
+        type=str,
+        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
+    )
+    parser.add_argument(
+        "--tokenizer_init_args",
         default="",
         type=str,
-        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path",
+        help="If instantiating a new tokenizer, comma-separated list of input args to feed the constructor.",
     )
     parser.add_argument(
         "--cache_dir",
-        default="",
+        default=None,
         type=str,
         help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
     )
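
Note: --tokenizer_init_args is a comma-separated string that is split and splatted into the tokenizer constructor (tokenizer_class(*args.tokenizer_init_args.split(",")), later in this diff). A minimal sketch of that mechanic, with a stand-in class rather than a real tokenizer:

    # Stand-in for a tokenizer class whose constructor takes positional vocab files.
    class ToyTokenizer:
        def __init__(self, vocab_file, merges_file):
            self.vocab_file, self.merges_file = vocab_file, merges_file

    # e.g. --tokenizer_init_args vocab.json,merges.txt on the command line
    init_args = "vocab.json,merges.txt"
    tokenizer = ToyTokenizer(*init_args.split(","))
    print(tokenizer.vocab_file, tokenizer.merges_file)  # vocab.json merges.txt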
@@ -493,9 +509,6 @@ def main():
     parser.add_argument(
         "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
     )
-    parser.add_argument(
-        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
-    )
 
     parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
     parser.add_argument(
@@ -563,7 +576,7 @@ def main():
     if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
         raise ValueError(
-            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
+            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
             "flag (masked language modeling)."
         )
     if args.eval_data_file is None and args.do_eval:
@@ -571,6 +584,14 @@ def main():
             "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
             "or remove the --do_eval argument."
         )
+    if args.should_continue:
+        sorted_checkpoints = _sorted_checkpoints(args)
+        if len(sorted_checkpoints) == 0:
+            raise ValueError(
+                "Used --should_continue but no checkpoint was found in --output_dir."
+            )
+        else:
+            args.model_name_or_path = sorted_checkpoints[-1]
 
     if (
         os.path.exists(args.output_dir)
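
Note: the net effect of --should_continue is to rewrite args.model_name_or_path to the newest checkpoint before training starts, so the existing resume logic in train() takes over. A minimal sketch of that wiring, with _sorted_checkpoints replaced by a hard-coded list of hypothetical directories:

    from argparse import Namespace

    args = Namespace(should_continue=True, output_dir="./output", model_name_or_path=None)
    sorted_checkpoints = ["./output/checkpoint-500", "./output/checkpoint-1000"]

    if args.should_continue:
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        args.model_name_or_path = sorted_checkpoints[-1]

    print(args.model_name_or_path)  # ./output/checkpoint-1000, i.e. resume from the latest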
@@ -627,26 +648,42 @@ def main():
         torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab
 
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
-        args.config_name if args.config_name else args.model_name_or_path,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
-    tokenizer = tokenizer_class.from_pretrained(
-        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
-        do_lower_case=args.do_lower_case,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
+
+    if args.config_name:
+        config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
+    elif args.model_name_or_path:
+        config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+    else:
+        config = config_class()
+
+    if args.tokenizer_name:
+        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
+    elif args.model_name_or_path:
+        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+    else:
+        logger.warning(
+            "You are instantiating a new {} tokenizer from scratch. Are you sure this is what you meant to do? "
+            "To specify a pretrained tokenizer name, use --tokenizer_name".format(tokenizer_class.__name__)
+        )
+        tokenizer = tokenizer_class(*args.tokenizer_init_args.split(","))
+
     if args.block_size <= 0:
-        args.block_size = (
-            tokenizer.max_len_single_sentence
-        )  # Our input block size will be the max possible for the model
-    args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
+        args.block_size = tokenizer.max_len_single_sentence
+        # Our input block size will be the max possible for the model
+    else:
+        args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
+
-    model = model_class.from_pretrained(
-        args.model_name_or_path,
-        from_tf=bool(".ckpt" in args.model_name_or_path),
-        config=config,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
+    if args.model_name_or_path:
+        model = model_class.from_pretrained(
+            args.model_name_or_path,
+            from_tf=bool(".ckpt" in args.model_name_or_path),
+            config=config,
+            cache_dir=args.cache_dir,
+        )
+    else:
+        logger.info("Training new model from scratch")
+        model = model_class(config=config)
+
     model.to(args.device)
 
     if args.local_rank == 0:
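
Note: with this hunk, omitting --model_name_or_path yields a randomly initialized model instead of an error. A minimal standalone sketch of the same from-scratch path (the model type and sizes are illustrative, not the script's defaults):

    from transformers import RobertaConfig, RobertaForMaskedLM

    config = RobertaConfig(vocab_size=52000, num_hidden_layers=6)  # fresh config, nothing downloaded
    model = RobertaForMaskedLM(config=config)  # randomly initialized weights
    print(model.config.num_hidden_layers)  # 6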
@@ -670,8 +707,8 @@ def main():
     # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
-        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
-            os.makedirs(args.output_dir)
+        if args.local_rank in [-1, 0]:
+            os.makedirs(args.output_dir, exist_ok=True)
 
         logger.info("Saving model checkpoint to %s", args.output_dir)
         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
@@ -687,7 +724,7 @@ def main():
         # Load a trained model and vocabulary that you have fine-tuned
         model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
         model.to(args.device)
 
     # Evaluation