Unverified Commit d9ece823 authored by Boris Dayma's avatar Boris Dayma Committed by GitHub
Browse files

fix(run_language_modeling): use arg overwrite_cache (#4407)

parent d39bf0ac
......@@ -120,7 +120,9 @@ def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, eva
if args.line_by_line:
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
else:
return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
return TextDataset(
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
)
def main():
......@@ -216,6 +218,7 @@ def main():
data_args.block_size = min(data_args.block_size, tokenizer.max_len)
# Get datasets
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
data_collator = DataCollatorForLanguageModeling(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment