Unverified commit 21d71923, authored by Jin Young (Daniel) Sohn and committed by GitHub

Add cache_dir to save features TextDataset (#6879)

* Add cache_dir to save features TextDataset

This is needed when the dataset lives on a read-only filesystem, which is the case
in tests (GKE TPU tests).

* style
parent 1461aac8
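
For context on the change below: cache_dir lets the cached features be pickled somewhere other than the directory containing the input file. A minimal usage sketch, assuming a GPT-2 tokenizer and temporary placeholder paths (none of which come from this commit):

import os
import tempfile

from transformers import GPT2Tokenizer, TextDataset

# Stand-in for training data that might otherwise live on a read-only
# filesystem (e.g. the mounted datasets in the GKE TPU tests).
data_dir = tempfile.mkdtemp()
train_file = os.path.join(data_dir, "train.txt")
with open(train_file, "w") as f:
    f.write("hello world " * 200)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=train_file,
    block_size=128,
    cache_dir=tempfile.mkdtemp(),  # writable location for the cached_lm_* pickle
)
print(len(dataset))

Without cache_dir, the features pickle is written next to train_file, which fails when that directory is read-only.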
@@ -125,13 +125,22 @@ class DataTrainingArguments:
     )

-def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
+def get_dataset(
+    args: DataTrainingArguments,
+    tokenizer: PreTrainedTokenizer,
+    evaluate: bool = False,
+    cache_dir: Optional[str] = None,
+):
     file_path = args.eval_data_file if evaluate else args.train_data_file
     if args.line_by_line:
         return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
     else:
         return TextDataset(
-            tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
+            tokenizer=tokenizer,
+            file_path=file_path,
+            block_size=args.block_size,
+            overwrite_cache=args.overwrite_cache,
+            cache_dir=cache_dir,
         )
@@ -229,8 +238,14 @@ def main():
     # Get datasets
-    train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
-    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
+    train_dataset = (
+        get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
+    )
+    eval_dataset = (
+        get_dataset(data_args, tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir)
+        if training_args.do_eval
+        else None
+    )
     if config.model_type == "xlnet":
         data_collator = DataCollatorForPermutationLanguageModeling(
             tokenizer=tokenizer,
...
 import os
 import pickle
 import time
+from typing import Optional

 import torch
 from torch.utils.data.dataset import Dataset
@@ -26,6 +27,7 @@ class TextDataset(Dataset):
         file_path: str,
         block_size: int,
         overwrite_cache=False,
+        cache_dir: Optional[str] = None,
     ):
         assert os.path.isfile(file_path), f"Input file path {file_path} not found"
@@ -33,7 +35,7 @@ class TextDataset(Dataset):
         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
-            directory,
+            cache_dir if cache_dir is not None else directory,
             "cached_lm_{}_{}_{}".format(
                 tokenizer.__class__.__name__,
                 str(block_size),
...
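
The net effect of the last hunk is the cache-path fallback. A simplified, self-contained sketch of that logic (the helper name and the example values are illustrative, not part of the library):

import os


def cached_features_path(file_path, tokenizer_name, block_size, cache_dir=None):
    # Mirrors the os.path.join call above: fall back to the data directory
    # only when no cache_dir is given.
    directory, filename = os.path.split(file_path)
    return os.path.join(
        cache_dir if cache_dir is not None else directory,
        "cached_lm_{}_{}_{}".format(tokenizer_name, str(block_size), filename),
    )


print(cached_features_path("/readonly/data/train.txt", "GPT2Tokenizer", 510))
# /readonly/data/cached_lm_GPT2Tokenizer_510_train.txt
print(cached_features_path("/readonly/data/train.txt", "GPT2Tokenizer", 510, cache_dir="/tmp/cache"))
# /tmp/cache/cached_lm_GPT2Tokenizer_510_train.txt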