Unverified Commit 7e6397a7 authored by Tomo Lazovich, committed by GitHub
Browse files

[squad] make examples and dataset accessible from SquadDataset object (#6710)

* [squad] make examples and dataset accessible from SquadDataset object

* [squad] add support for legacy cache files
parent ac9702c2
...@@ -103,6 +103,7 @@ class SquadDataset(Dataset): ...@@ -103,6 +103,7 @@ class SquadDataset(Dataset):
mode: Union[str, Split] = Split.train, mode: Union[str, Split] = Split.train,
is_language_sensitive: Optional[bool] = False, is_language_sensitive: Optional[bool] = False,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
dataset_format: Optional[str] = "pt",
): ):
self.args = args self.args = args
self.is_language_sensitive = is_language_sensitive self.is_language_sensitive = is_language_sensitive
...@@ -128,28 +129,43 @@ class SquadDataset(Dataset): ...@@ -128,28 +129,43 @@ class SquadDataset(Dataset):
with FileLock(lock_path): with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache: if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time() start = time.time()
self.features = torch.load(cached_features_file) self.old_features = torch.load(cached_features_file)
# Legacy cache files have only features, while new cache files
# will have dataset and examples also.
self.features = self.old_features["features"]
self.dataset = self.old_features.get("dataset", None)
self.examples = self.old_features.get("examples", None)
logger.info( logger.info(
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
) )
if self.dataset is None or self.examples is None:
logger.warn(
f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in future run"
)
else: else:
if mode == Split.dev: if mode == Split.dev:
examples = self.processor.get_dev_examples(args.data_dir) self.examples = self.processor.get_dev_examples(args.data_dir)
else: else:
examples = self.processor.get_train_examples(args.data_dir) self.examples = self.processor.get_train_examples(args.data_dir)
self.features = squad_convert_examples_to_features( self.features, self.dataset = squad_convert_examples_to_features(
examples=examples, examples=self.examples,
tokenizer=tokenizer, tokenizer=tokenizer,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride, doc_stride=args.doc_stride,
max_query_length=args.max_query_length, max_query_length=args.max_query_length,
is_training=mode == Split.train, is_training=mode == Split.train,
threads=args.threads, threads=args.threads,
return_dataset=dataset_format,
) )
start = time.time() start = time.time()
torch.save(self.features, cached_features_file) torch.save(
{"features": self.features, "dataset": self.dataset, "examples": self.examples},
cached_features_file,
)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve. # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
logger.info( logger.info(
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment