Unverified commit cda25fef authored by Lintang Sutawika, committed by GitHub

Merge branch 'main' into standardize_metrics

parents dfb41835 4d10ad56
@@ -19,4 +19,4 @@ metric_list:
     ignore_case: true
     ignore_punctuation: true
 metadata:
-  - version: 0.0
+  version: 0.0
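The same change repeats across the task configs in this merge: the single-item list under `metadata:` becomes a plain mapping, so `version` is read as a key rather than as the first element of a list. A minimal PyYAML sketch (not the harness's own config loader) of what the two forms parse to:

```python
import yaml

# Old form: `metadata` parses as a list containing one mapping.
old_cfg = yaml.safe_load("metadata:\n  - version: 0.0\n")
assert old_cfg["metadata"] == [{"version": 0.0}]

# New form: `metadata` parses as a plain mapping, so the version is a direct key.
new_cfg = yaml.safe_load("metadata:\n  version: 0.0\n")
assert new_cfg["metadata"] == {"version": 0.0}
print(new_cfg["metadata"]["version"])  # 0.0
```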
@@ -12,4 +12,4 @@ doc_to_choice: "['''{{answer}}\\nIs the answer correct? yes''', '''{{answer}}\\n
 metric_list:
   - metric: acc
 metadata:
-  - version: 2.0
+  version: 2.0
@@ -20,4 +20,4 @@ metric_list:
     aggregation: !function t5_utils.agg_em
     higher_is_better: true
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -5,7 +5,6 @@ import sklearn.metrics
 def f1(predictions, references):  # This is a passthrough function
     _prediction = predictions[0]
     _reference = references[0].split("_")[-1]
     string_label = ["False", "True"]
@@ -20,7 +19,6 @@ def f1(predictions, references):  # This is a passthrough function
 def agg_f1(items):
     predictions, references = zip(*items)
     references, predictions = np.asarray(references), np.asarray(predictions)
@@ -28,7 +26,6 @@ def agg_f1(items):
 def em(predictions, references):  # This is a passthrough function
     _prediction = predictions[0]
     _group, _reference = references[0].split("_")
     string_label = ["False", "True"]
...
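For context on the hunk above: `f1` and `em` are per-instance passthroughs that only pair a prediction with its reference, while the corpus-level score is computed once in the aggregation function over all collected pairs. A minimal sketch of that pattern; the label handling and the `cb_`-style reference prefix are assumptions, only the passthrough/aggregation split mirrors the code above:

```python
import numpy as np
import sklearn.metrics


def f1(predictions, references):  # passthrough: return one (prediction, reference) pair
    return predictions[0], references[0].split("_")[-1]


def agg_f1(items):  # called once with every pair collected by f1()
    predictions, references = zip(*items)
    string_label = ["False", "True"]  # assumed label encoding
    preds = np.asarray([string_label.index(p) if p in string_label else 0 for p in predictions])
    refs = np.asarray([string_label.index(r) for r in references])
    return sklearn.metrics.f1_score(refs, preds)


# Hypothetical references of the form "<group>_<label>", as split in the diff above.
items = [f1(["True"], ["cb_True"]), f1(["False"], ["cb_False"]), f1(["True"], ["cb_False"])]
print(agg_f1(items))  # 0.666...
```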
@@ -17,4 +17,4 @@ metric_list:
     higher_is_better: True
     aggregation: mean
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -19,4 +19,4 @@ metric_list:
     aggregation: !function t5_utils.squad_f1_agg
     higher_is_better: true
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -3,14 +3,12 @@ import string
 import collections
 import numpy as np
-from tqdm import tqdm
-from datasets import Dataset, concatenate_datasets
+from datasets import Dataset
 from lm_eval.api.metrics import metric_max_over_ground_truths
 def doc_to_text(doc):
     passage = doc["passage"]
     passage = re.sub(r"(\.|\?|\!|\"|\')\n@highlight\n", r"\1 ", passage)
     passage = re.sub(r"\n@highlight\n", ". ", passage)
@@ -34,7 +32,6 @@ def process_docs(dataset):
     }
     answers = doc.pop("answers")
     for idx, answer in enumerate(answers):
         for key in split_doc.keys():
             if key in doc:
                 split_doc[key].append(doc[key])
...
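The two `re.sub` calls in `doc_to_text` above turn ReCoRD's `\n@highlight\n` separators into ordinary sentence breaks, keeping end-of-sentence punctuation where it is already present. A toy illustration (the passage text here is made up; the patterns are the ones from the diff):

```python
import re

passage = "The match ended late.\n@highlight\nCity won 2-1\n@highlight\nFans celebrated"

# Keep existing sentence-ending punctuation before a highlight marker...
passage = re.sub(r"(\.|\?|\!|\"|\')\n@highlight\n", r"\1 ", passage)
# ...and turn any remaining markers into ". " sentence breaks.
passage = re.sub(r"\n@highlight\n", ". ", passage)

print(passage)  # The match ended late. City won 2-1. Fans celebrated
```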
@@ -12,4 +12,4 @@ doc_to_choice: ['True', 'False']
 metric_list:
   - metric: acc
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -19,4 +19,4 @@ metric_list:
     ignore_case: true
     ignore_punctuation: true
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -12,4 +12,4 @@ doc_to_choice: ['no', 'yes']
 metric_list:
   - metric: acc
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -19,4 +19,4 @@ metric_list:
     ignore_case: true
     ignore_punctuation: true
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -12,4 +12,4 @@ doc_to_choice: ['no', 'yes']
 metric_list:
   - metric: acc
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -20,4 +20,4 @@ filter_list:
     filter:
       - function: !function t5_utils.WSCPostprocess
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -8,7 +8,6 @@ def doc_to_text(x):
 def _wsc_inputs(x):
     words = x["text"].split(" ")
     # We would need some special logic to handle the case where the pronoun is the
@@ -55,7 +54,6 @@ def _wsc_inputs(x):
 class WSCPostprocess(Filter):
     def __init__(self, **kwargs):
         self.determiners = {
             "a",
             "an",
@@ -86,10 +84,8 @@ class WSCPostprocess(Filter):
         return " ".join([w for w in s.split(" ") if w not in self.determiners])
     def apply(self, resps, docs):
         filtered_resps = []
         for prediction, reference in zip(*(resps, docs["span1_text"])):
             prediction = self.clean(prediction[0])
             reference = self.clean(reference)
...
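`WSCPostprocess.apply` above compares each cleaned model response against the cleaned `span1_text` reference; the visible `return` line shows that cleaning strips determiners from the string. A minimal sketch of that clean-and-compare step, assuming lowercasing and punctuation removal and using only a truncated, assumed subset of the determiner set:

```python
import string

determiners = {"a", "an", "the", "this", "that", "his", "her", "their"}  # truncated, assumed subset


def clean(s: str) -> str:
    # Lowercase, drop punctuation, then remove determiners (assumed normalization).
    s = s.lower().translate(str.maketrans("", "", string.punctuation))
    return " ".join(w for w in s.split(" ") if w not in determiners)


prediction, reference = "The defenders", "defenders"
print(clean(prediction) == clean(reference))  # True: they match once determiners are stripped
```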
@@ -16,4 +16,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -15,4 +15,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 1.0
+  version: 1.0
 import argparse
-from typing import Dict, List
 import yaml
-import sacrebleu
 try:
     import pycountry
...
@@ -14,4 +14,4 @@ generation_kwargs:
   temperature: 0.0
 repeats: 1
 metadata:
-  - version: 0.0
+  version: 0.0
@@ -28,4 +28,4 @@ metric_list:
     ignore_case: true
     ignore_punctuation: true
 metadata:
-  - version: 2.0
+  version: 2.0
@@ -76,4 +76,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 2.0
+  version: 2.0