Commit d7929899 authored by Lysandre
Browse files

Specify checkpoint in saved file for run_lm_finetuning.py

parent e18f786c
......@@ -63,10 +63,10 @@ MODEL_CLASSES = {
class TextDataset(Dataset):
def __init__(self, tokenizer, file_path='train', block_size=512):
def __init__(self, tokenizer, args, file_path='train', block_size=512):
assert os.path.isfile(file_path)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(directory, 'cached_lm_' + str(block_size) + '_' + filename)
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
......@@ -99,7 +99,7 @@ class TextDataset(Dataset):
def load_and_cache_examples(args, tokenizer, evaluate=False):
dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
return dataset
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment