Commit 223da391 authored by lintangsutawika

fix configs

parent 1813bf04
@@ -5,11 +5,15 @@ dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "cb hypothesis: {{hypothesis}} premise {{premise}}"
doc_to_target: "{% set answer_choices = ['entailment', 'contradiction', 'neutral'] %}{{answer_choices[label]}}"
doc_to_target: label
doc_to_choice: ['entailment', 'contradiction', 'neutral']
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
# - metric: f1
#   aggregation: !function "aggregate.cb_multi_fi"
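The cb config above now takes the raw class index from doc_to_target: label and maps it onto doc_to_choice, rather than rendering the target string through a Jinja template. Below is a minimal sketch of how those two fields might combine into the reference string that exact_match compares against; resolve_target and the example doc are illustrative, not harness code:

def resolve_target(doc, choices=("entailment", "contradiction", "neutral")):
    # doc_to_target: label -> read the integer class index from the document
    label_index = doc["label"]
    # doc_to_choice -> map that index onto the declared choice strings
    return choices[label_index]

example_doc = {"premise": "It rained all night.", "hypothesis": "The ground is wet.", "label": 0}
print(resolve_target(example_doc))  # -> "entailment"

With ignore_case and ignore_punctuation set, the exact_match comparison against this string is case- and punctuation-insensitive.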
@@ -5,8 +5,10 @@ dataset_path: super_glue
dataset_name: copa
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "copa choice1: {{choice1}} choice2: {{choice2}} question: {{question}}"
doc_to_target: "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"
doc_to_target: label
doc_to_choice: ['False', 'True']
metric_list:
  - metric: exact_match
    aggregation: mean
......
# group:
# - super-glue-lm-eval-v1
group:
- super-glue-lm-eval-v1
task: record
dataset_path: super_glue
dataset_name: record
@@ -9,6 +9,10 @@ validation_split: validation
doc_to_text: !function util.doc_to_text
doc_to_target: "{{answers}}"
doc_to_choice: "{{entities}}"
process_results: !function util.process_results
metric_list:
  - metric: f1
    aggregation: mean
  - metric: em
    higher_is_better: True
    aggregation: mean
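Since process_results already returns per-example f1 and em scores, the metric_list entries above only need to name those keys and how to pool them. A rough sketch of the mean aggregation declared here, assuming one score dict per evaluated document; the numbers are made up for illustration:

from statistics import mean

# Hypothetical per-example outputs, i.e. what util.process_results returns for each doc.
per_example = [{"f1": 1.0, "em": 1}, {"f1": 0.5, "em": 0}, {"f1": 1.0, "em": 1}]

# aggregation: mean for both metrics in the config above.
aggregated = {key: mean(ex[key] for ex in per_example) for key in ("f1", "em")}
print(aggregated)  # -> {'f1': 0.83..., 'em': 0.67...}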
@@ -5,6 +5,7 @@ dataset_path: super_glue
dataset_name: record
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "record query: {{query}} entities: {{entities}} passage: {{passage}}"
doc_to_target: "{{answers}}"
metric_list:
......
import numpy as np
import transformers.data.metrics.squad_metrics as squad_metrics

from lm_eval.api.metrics import metric_max_over_ground_truths


def doc_to_text(doc):
    initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
    text = initial_text + "\n\n"
@@ -13,3 +19,25 @@ def format_answer(query, entity):
def doc_to_target(doc):
    # We only output the first correct entity in a doc
    return format_answer(query=doc["query"], entity=doc["answers"][0])


def process_results(doc, results):
    # ReCoRD's evaluation is actually deceptively simple:
    # - Pick the maximum likelihood prediction entity
    # - Evaluate the accuracy and token F1 PER EXAMPLE
    # - Average over all examples
    max_idx = np.argmax(np.array([result[0] for result in results]))
    prediction = doc["entities"][max_idx]
    gold_label_set = doc["answers"]
    f1 = metric_max_over_ground_truths(
        squad_metrics.compute_f1, prediction, gold_label_set
    )
    em = metric_max_over_ground_truths(
        squad_metrics.compute_exact, prediction, gold_label_set
    )
    return {
        "f1": f1,
        "em": em,
    }
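For reference, a hypothetical call to the process_results function above, assuming the harness passes one score tuple per candidate in doc["entities"] with the log-likelihood in position 0; the document and scores below are made up for illustration:

doc = {
    "query": "placeholder query",
    "entities": ["Paris", "London", "Berlin"],
    "answers": ["Paris"],
}
# One score tuple per entity; the first candidate has the highest log-likelihood.
results = [(-1.2,), (-3.4,), (-2.8,)]
print(process_results(doc, results))  # f1 and em are both 1 when the top entity matches the gold answer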
@@ -5,6 +5,7 @@ dataset_path: super_glue
dataset_name: wsc
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: !function "preprocess_wsc.t5_prompt_doc_to_text"
doc_to_target: label
doc_to_choice: ['False', 'True']
......