Unverified Commit 7e6397a7 authored by Tomo Lazovich's avatar Tomo Lazovich Committed by GitHub
Browse files

[squad] make examples and dataset accessible from SquadDataset object (#6710)

* [squad] make examples and dataset accessible from SquadDataset object

* [squad] add support for legacy cache files
parent ac9702c2
......@@ -103,6 +103,7 @@ class SquadDataset(Dataset):
mode: Union[str, Split] = Split.train,
is_language_sensitive: Optional[bool] = False,
cache_dir: Optional[str] = None,
dataset_format: Optional[str] = "pt",
):
self.args = args
self.is_language_sensitive = is_language_sensitive
......@@ -128,28 +129,43 @@ class SquadDataset(Dataset):
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
self.features = torch.load(cached_features_file)
self.old_features = torch.load(cached_features_file)
# Legacy cache files have only features, while new cache files
# will have dataset and examples also.
self.features = self.old_features["features"]
self.dataset = self.old_features.get("dataset", None)
self.examples = self.old_features.get("examples", None)
logger.info(
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
)
if self.dataset is None or self.examples is None:
logger.warn(
f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in future run"
)
else:
if mode == Split.dev:
examples = self.processor.get_dev_examples(args.data_dir)
self.examples = self.processor.get_dev_examples(args.data_dir)
else:
examples = self.processor.get_train_examples(args.data_dir)
self.examples = self.processor.get_train_examples(args.data_dir)
self.features = squad_convert_examples_to_features(
examples=examples,
self.features, self.dataset = squad_convert_examples_to_features(
examples=self.examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=mode == Split.train,
threads=args.threads,
return_dataset=dataset_format,
)
start = time.time()
torch.save(self.features, cached_features_file)
torch.save(
{"features": self.features, "dataset": self.dataset, "examples": self.examples},
cached_features_file,
)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
logger.info(
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment