"vscode:/vscode.git/clone" did not exist on "3c5c567507032da4c9dec00facf9ea4043e634eb"
Unverified Commit b231a413 authored by Jin Young Sohn's avatar Jin Young Sohn Committed by GitHub
Browse files

Add cache_dir to save features in GLUE + Differentiate match/mismatch for MNLI metrics (#4621)



* Glue task cleaup

* Enable writing cache to cache_dir in case dataset lives in readOnly
filesystem.
* Differentiate match vs mismatch for MNLI metrics.

* Style

* Fix pytype

* Fix type

* Use cache_dir in mnli mismatch eval dataset

* Small Tweaks
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>
parent 70f74234
......@@ -21,7 +21,7 @@ import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Callable, Dict, Optional
import numpy as np
......@@ -134,16 +134,29 @@ def main():
)
# Get datasets
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test") if training_args.do_predict else None
train_dataset = (
GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
)
eval_dataset = (
GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
if training_args.do_eval
else None
)
test_dataset = (
GlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
if training_args.do_predict
else None
)
def compute_metrics(p: EvalPrediction) -> Dict:
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def compute_metrics_fn(p: EvalPrediction):
if output_mode == "classification":
preds = np.argmax(p.predictions, axis=1)
elif output_mode == "regression":
preds = np.squeeze(p.predictions)
return glue_compute_metrics(data_args.task_name, preds, p.label_ids)
return glue_compute_metrics(task_name, preds, p.label_ids)
return compute_metrics_fn
# Initialize our Trainer
trainer = Trainer(
......@@ -151,7 +164,7 @@ def main():
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
compute_metrics=build_compute_metrics_fn(data_args.task_name),
)
# Training
......@@ -174,9 +187,12 @@ def main():
eval_datasets = [eval_dataset]
if data_args.task_name == "mnli":
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev"))
eval_datasets.append(
GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
)
for eval_dataset in eval_datasets:
trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name)
eval_result = trainer.evaluate(eval_dataset=eval_dataset)
output_eval_file = os.path.join(
......@@ -196,7 +212,9 @@ def main():
test_datasets = [test_dataset]
if data_args.task_name == "mnli":
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
test_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test"))
test_datasets.append(
GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
)
for test_dataset in test_datasets:
predictions = trainer.predict(test_dataset=test_dataset).predictions
......
......@@ -70,6 +70,7 @@ class GlueDataset(Dataset):
tokenizer: PreTrainedTokenizer,
limit_length: Optional[int] = None,
mode: Union[str, Split] = Split.train,
cache_dir: Optional[str] = None,
):
self.args = args
self.processor = glue_processors[args.task_name]()
......@@ -81,7 +82,7 @@ class GlueDataset(Dataset):
raise KeyError("mode is not a valid split name")
# Load data features from cache or dataset file
cached_features_file = os.path.join(
args.data_dir,
cache_dir if cache_dir is not None else args.data_dir,
"cached_{}_{}_{}_{}".format(
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
),
......
......@@ -63,9 +63,9 @@ if _has_sklearn:
elif task_name == "qqp":
return acc_and_f1(preds, labels)
elif task_name == "mnli":
return {"acc": simple_accuracy(preds, labels)}
return {"mnli/acc": simple_accuracy(preds, labels)}
elif task_name == "mnli-mm":
return {"acc": simple_accuracy(preds, labels)}
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
elif task_name == "qnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "rte":
......
......@@ -553,6 +553,7 @@ class Trainer:
if self.tb_writer:
for k, v in logs.items():
self.tb_writer.add_scalar(k, v, self.global_step)
self.tb_writer.flush()
if is_wandb_available():
wandb.log(logs, step=self.global_step)
output = json.dumps({**logs, **{"step": self.global_step}})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment