Unverified Commit b231a413 authored by Jin Young Sohn, committed by GitHub

Add cache_dir to save features in GLUE + Differentiate match/mismatch for MNLI metrics (#4621)



* GLUE task cleanup

* Enable writing the feature cache to cache_dir in case the dataset lives on a read-only
filesystem (a short usage sketch follows the commit metadata below).
* Differentiate match vs mismatch for MNLI metrics.

* Style

* Fix pytype

* Fix type

* Use cache_dir in mnli mismatch eval dataset

* Small Tweaks
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent 70f74234
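A rough usage sketch (not part of this commit) of what the new cache_dir argument enables: cached features can be written to a writable location instead of next to the data, which matters when data_dir is mounted read-only. The checkpoint name and paths below are hypothetical placeholders.

from transformers import AutoTokenizer, GlueDataset, GlueDataTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
data_args = GlueDataTrainingArguments(
    task_name="mnli",
    data_dir="/mnt/readonly/glue/MNLI",  # hypothetical read-only mount
    max_seq_length=128,
)
# Features are cached under /tmp/glue_cache rather than inside data_dir.
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir="/tmp/glue_cache")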
@@ -21,7 +21,7 @@ import logging
 import os
 import sys
 from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional

 import numpy as np
@@ -134,16 +134,29 @@ def main():
     )

     # Get datasets
-    train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
-    eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
-    test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test") if training_args.do_predict else None
+    train_dataset = (
+        GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
+    )
+    eval_dataset = (
+        GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
+        if training_args.do_eval
+        else None
+    )
+    test_dataset = (
+        GlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
+        if training_args.do_predict
+        else None
+    )

-    def compute_metrics(p: EvalPrediction) -> Dict:
-        if output_mode == "classification":
-            preds = np.argmax(p.predictions, axis=1)
-        elif output_mode == "regression":
-            preds = np.squeeze(p.predictions)
-        return glue_compute_metrics(data_args.task_name, preds, p.label_ids)
+    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
+        def compute_metrics_fn(p: EvalPrediction):
+            if output_mode == "classification":
+                preds = np.argmax(p.predictions, axis=1)
+            elif output_mode == "regression":
+                preds = np.squeeze(p.predictions)
+            return glue_compute_metrics(task_name, preds, p.label_ids)
+
+        return compute_metrics_fn

     # Initialize our Trainer
     trainer = Trainer(
@@ -151,7 +164,7 @@ def main():
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        compute_metrics=compute_metrics,
+        compute_metrics=build_compute_metrics_fn(data_args.task_name),
     )

     # Training
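The diff above replaces the single compute_metrics function with a small factory; binding task_name in a closure is what later lets the same Trainer report matched and mismatched MNLI results separately. A minimal standalone sketch of the pattern (classification case only, assuming sklearn is installed so glue_compute_metrics is available):

import numpy as np
from transformers import EvalPrediction, glue_compute_metrics

def build_compute_metrics_fn(task_name: str):
    def compute_metrics_fn(p: EvalPrediction):
        preds = np.argmax(p.predictions, axis=1)  # classification case only in this sketch
        return glue_compute_metrics(task_name, preds, p.label_ids)
    return compute_metrics_fn

mnli_fn = build_compute_metrics_fn("mnli")        # will report "mnli/acc"
mnli_mm_fn = build_compute_metrics_fn("mnli-mm")  # will report "mnli-mm/acc"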
@@ -174,9 +187,12 @@ def main():
         eval_datasets = [eval_dataset]
         if data_args.task_name == "mnli":
             mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
-            eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev"))
+            eval_datasets.append(
+                GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
+            )

         for eval_dataset in eval_datasets:
+            trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name)
             eval_result = trainer.evaluate(eval_dataset=eval_dataset)

             output_eval_file = os.path.join(
@@ -196,7 +212,9 @@ def main():
         test_datasets = [test_dataset]
         if data_args.task_name == "mnli":
             mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
-            test_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test"))
+            test_datasets.append(
+                GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
+            )

         for test_dataset in test_datasets:
             predictions = trainer.predict(test_dataset=test_dataset).predictions
...
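Both the evaluation and prediction branches above rely on dataclasses.replace, which is standard-library behaviour: it returns a copy of the dataclass with only the named fields changed, so the mismatched split inherits every other data argument. A small illustration with a stand-in dataclass:

import dataclasses

@dataclasses.dataclass
class DataArgs:  # stand-in for GlueDataTrainingArguments; fields are illustrative
    task_name: str
    data_dir: str = "/data/MNLI"
    max_seq_length: int = 128

data_args = DataArgs(task_name="mnli")
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
assert mnli_mm_data_args.task_name == "mnli-mm"
assert mnli_mm_data_args.data_dir == data_args.data_dir  # everything else carried over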
@@ -70,6 +70,7 @@ class GlueDataset(Dataset):
         tokenizer: PreTrainedTokenizer,
         limit_length: Optional[int] = None,
         mode: Union[str, Split] = Split.train,
+        cache_dir: Optional[str] = None,
     ):
         self.args = args
         self.processor = glue_processors[args.task_name]()
@@ -81,7 +82,7 @@ class GlueDataset(Dataset):
                 raise KeyError("mode is not a valid split name")
         # Load data features from cache or dataset file
         cached_features_file = os.path.join(
-            args.data_dir,
+            cache_dir if cache_dir is not None else args.data_dir,
             "cached_{}_{}_{}_{}".format(
                 mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
             ),
...
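For illustration only, with made-up values (a BertTokenizer, max_seq_length of 128, the MNLI dev split, and a writable cache_dir), the path built above would resolve roughly like this:

import os

cache_dir = "/tmp/glue_cache"  # hypothetical writable location
cached_features_file = os.path.join(
    cache_dir,
    "cached_{}_{}_{}_{}".format("dev", "BertTokenizer", str(128), "mnli"),
)
print(cached_features_file)  # /tmp/glue_cache/cached_dev_BertTokenizer_128_mnli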
@@ -63,9 +63,9 @@ if _has_sklearn:
         elif task_name == "qqp":
             return acc_and_f1(preds, labels)
         elif task_name == "mnli":
-            return {"acc": simple_accuracy(preds, labels)}
+            return {"mnli/acc": simple_accuracy(preds, labels)}
         elif task_name == "mnli-mm":
-            return {"acc": simple_accuracy(preds, labels)}
+            return {"mnli-mm/acc": simple_accuracy(preds, labels)}
         elif task_name == "qnli":
             return {"acc": simple_accuracy(preds, labels)}
         elif task_name == "rte":
...
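With the renamed keys, the matched and mismatched MNLI evaluations no longer write to the same metric name in logs. A hedged sketch of the resulting dictionaries (assuming sklearn is installed so glue_compute_metrics is defined; values are made up):

import numpy as np
from transformers import glue_compute_metrics

preds = np.array([0, 1, 2, 1])
labels = np.array([0, 1, 2, 2])
print(glue_compute_metrics("mnli", preds, labels))     # expected: {'mnli/acc': 0.75}
print(glue_compute_metrics("mnli-mm", preds, labels))  # expected: {'mnli-mm/acc': 0.75}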
@@ -553,6 +553,7 @@ class Trainer:
         if self.tb_writer:
             for k, v in logs.items():
                 self.tb_writer.add_scalar(k, v, self.global_step)
+            self.tb_writer.flush()
         if is_wandb_available():
             wandb.log(logs, step=self.global_step)
         output = json.dumps({**logs, **{"step": self.global_step}})
...
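The added flush() follows standard SummaryWriter behaviour: events are buffered and written to disk only periodically, so flushing after each logging call makes the scalars visible in TensorBoard right away. A minimal standalone sketch with a hypothetical log directory:

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="/tmp/runs")  # hypothetical log directory
tb_writer.add_scalar("mnli/acc", 0.75, global_step=100)
tb_writer.flush()  # force buffered events to disk so they show up immediately
tb_writer.close()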