# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import sys

import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file.

        Args:
            input_file: Path of the file to read. It is decoded as
                "utf-8-sig", so a leading byte-order mark is dropped.
            quotechar: Passed straight through to `csv.reader`.

        Returns:
            A list with one entry per record, each entry being the list of
            that record's cells as strings.
        """
        # The former Python 2 branch (re-decoding cells with `unicode`) was
        # unreachable in practice: `open(..., encoding=...)` is itself
        # Python 3 only, so the rows can be collected directly.
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))


def simple_accuracy(preds, labels):
    """Return the fraction of positions at which `preds` equals `labels`.

    Both arguments are coerced with `np.asarray`, so plain lists and tuples
    work as well as arrays (comparing two lists directly would yield a single
    bool rather than an element-wise result).
    """
    return (np.asarray(preds) == np.asarray(labels)).mean()


def acc_and_f1(preds, labels):
    """Return accuracy, F1 and their arithmetic mean.

    F1 is computed by `sklearn.metrics.f1_score` with its default settings.

    Returns:
        A dict with the keys "acc", "f1" and "acc_and_f1".
    """
    acc = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds)
    return {
        "acc": acc,
        "f1": f1,
        "acc_and_f1": (acc + f1) / 2,
    }


def pearson_and_spearman(preds, labels):
    """Return the Pearson and Spearman correlations and their mean.

    Only the correlation coefficient (element 0 of each scipy result) is
    kept; the accompanying p-value is discarded.

    Returns:
        A dict with the keys "pearson", "spearmanr" and "corr".
    """
    pearson_corr = pearsonr(preds, labels)[0]
    spearman_corr = spearmanr(preds, labels)[0]
    return {
        "pearson": pearson_corr,
        "spearmanr": spearman_corr,
        "corr": (pearson_corr + spearman_corr) / 2,
    }


# Task names grouped by the metric set compute_metrics reports for them.
_ACCURACY_TASKS = ("sst-2", "mnli", "mnli-mm", "qnli", "rte", "wnli")
_ACC_AND_F1_TASKS = ("mrpc", "qqp")


def compute_metrics(task_name, preds, labels):
    """Compute the evaluation metrics associated with `task_name`.

    Args:
        task_name: One of "cola", "sst-2", "mrpc", "sts-b", "qqp", "mnli",
            "mnli-mm", "qnli", "rte" or "wnli".
        preds: Predictions; must have the same length as `labels`.
        labels: Reference values.

    Returns:
        A dict mapping metric names to values: {"mcc"} for "cola", the
        `pearson_and_spearman` dict for "sts-b", the `acc_and_f1` dict for
        "mrpc" and "qqp", and {"acc"} for every other supported task.

    Raises:
        AssertionError: If `preds` and `labels` differ in length.
        KeyError: If `task_name` is not a supported task.
    """
    assert len(preds) == len(labels), (
        "preds and labels differ in length: %d vs %d"
        % (len(preds), len(labels))
    )
    if task_name == "cola":
        # matthews_corrcoef takes the reference labels first.
        return {"mcc": matthews_corrcoef(labels, preds)}
    if task_name == "sts-b":
        return pearson_and_spearman(preds, labels)
    if task_name in _ACC_AND_F1_TASKS:
        return acc_and_f1(preds, labels)
    if task_name in _ACCURACY_TASKS:
        return {"acc": simple_accuracy(preds, labels)}
    raise KeyError(task_name)