"llama/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "25911a6e6bd5a0cf209d871c721aa7bc74f59509"
Unverified Commit 21d71923, authored by Jin Young (Daniel) Sohn, committed by GitHub

Add cache_dir to save features TextDataset (#6879)

* Add cache_dir to save features TextDataset

This is for the case where the dataset lives on a read-only (RO) filesystem,
as it does in tests (GKE TPU tests).

* style
parent 1461aac8
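
For context, a minimal usage sketch of what this change enables (not part of the commit; the paths, tokenizer choice, and temp directory are illustrative assumptions): after this patch, TextDataset accepts a cache_dir, so the pickled feature cache can be written to a writable location instead of next to the input file.

import tempfile

from transformers import GPT2Tokenizer, TextDataset

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Before this commit the "cached_lm_*" pickle was always written next to the
# input file, which fails when the data directory is mounted read-only
# (e.g. in the GKE TPU tests mentioned above). With cache_dir, the cache is
# redirected to a writable scratch directory.
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="/data/train.txt",       # assumed read-only location
    block_size=128,
    cache_dir=tempfile.mkdtemp(),      # assumed writable scratch directory
)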
@@ -125,13 +125,22 @@ class DataTrainingArguments:
     )
-def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
+def get_dataset(
+    args: DataTrainingArguments,
+    tokenizer: PreTrainedTokenizer,
+    evaluate: bool = False,
+    cache_dir: Optional[str] = None,
+):
     file_path = args.eval_data_file if evaluate else args.train_data_file
     if args.line_by_line:
         return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
     else:
         return TextDataset(
-            tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
+            tokenizer=tokenizer,
+            file_path=file_path,
+            block_size=args.block_size,
+            overwrite_cache=args.overwrite_cache,
+            cache_dir=cache_dir,
         )
@@ -229,8 +238,14 @@ def main():
     # Get datasets
-    train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
-    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
+    train_dataset = (
+        get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
+    )
+    eval_dataset = (
+        get_dataset(data_args, tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir)
+        if training_args.do_eval
+        else None
+    )
     if config.model_type == "xlnet":
         data_collator = DataCollatorForPermutationLanguageModeling(
             tokenizer=tokenizer,
......
 import os
 import pickle
 import time
+from typing import Optional

 import torch
 from torch.utils.data.dataset import Dataset
@@ -26,6 +27,7 @@ class TextDataset(Dataset):
         file_path: str,
         block_size: int,
         overwrite_cache=False,
+        cache_dir: Optional[str] = None,
     ):
         assert os.path.isfile(file_path), f"Input file path {file_path} not found"
@@ -33,7 +35,7 @@ class TextDataset(Dataset):
         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
-            directory,
+            cache_dir if cache_dir is not None else directory,
             "cached_lm_{}_{}_{}".format(
                 tokenizer.__class__.__name__,
                 str(block_size),
......
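
Read on its own, the cache-path resolution after this change works as in the standalone sketch below. This is a restatement for illustration, not code from the repository; the trailing filename component of the cache name falls outside the visible hunk and is assumed to follow the existing cached_lm_{tokenizer}_{block_size}_{filename} pattern.

import os
from typing import Optional


def resolve_cached_features_file(
    file_path: str,
    tokenizer_class_name: str,
    block_size: int,
    cache_dir: Optional[str] = None,
) -> str:
    # Mirrors the logic in TextDataset.__init__ after this commit: the cache
    # goes to cache_dir when one is given, otherwise next to the input file.
    directory, filename = os.path.split(file_path)
    return os.path.join(
        cache_dir if cache_dir is not None else directory,
        "cached_lm_{}_{}_{}".format(tokenizer_class_name, str(block_size), filename),
    )


# Without cache_dir the cache lands beside the (possibly read-only) data file;
# with cache_dir it lands in the writable location supplied by the caller.
print(resolve_cached_features_file("/data/train.txt", "GPT2Tokenizer", 128))
# /data/cached_lm_GPT2Tokenizer_128_train.txt
print(resolve_cached_features_file("/data/train.txt", "GPT2Tokenizer", 128, cache_dir="/tmp/features"))
# /tmp/features/cached_lm_GPT2Tokenizer_128_train.txt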