Unverified commit 6598967b authored by Leo Gao, committed by GitHub

Merge pull request #106 from jon-tow/anli-evaluation

Implement `ANLI` evaluations
parents ec45d0aa 0a53c06e
@@ -58,9 +58,9 @@ TASK_REGISTRY = {
     # "webqs": webqs.WebQs, # not implemented yet
     # "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
     # "winogrande": winogrande.Winogrande, # not implemented yet
-    # "anli_r1": anli.ANLIRound1, # not implemented yet
-    # "anli_r2": anli.ANLIRound2, # not implemented yet
-    # "anli_r3": anli.ANLIRound3, # not implemented yet
+    "anli_r1": anli.ANLIRound1,
+    "anli_r2": anli.ANLIRound2,
+    "anli_r3": anli.ANLIRound3,
     # arithmetic
     "arithmetic_2da": arithmetic.Arithmetic2DPlus,
     "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
...
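The registry change above is what actually exposes the new tasks: with the three entries uncommented, each ANLI round can be looked up by its string name. A minimal sketch of that lookup, assuming `TASK_REGISTRY` is a plain dict mapping task names to task classes as the hunk shows (the `get_task` helper and the `lm_eval.tasks` module path are illustrative assumptions, not confirmed by this diff):

    from lm_eval import tasks  # assumed location of TASK_REGISTRY

    def get_task(task_name):
        # Plain dict lookup; raises KeyError for names not in the registry.
        return tasks.TASK_REGISTRY[task_name]

    anli_r1 = get_task("anli_r1")()  # instantiates anli.ANLIRound1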
+import numpy as np
+from lm_eval.base import rf, mean
 from . common import HFTask

 class ANLIBase(HFTask):
@@ -33,7 +35,6 @@ class ANLIBase(HFTask):
         return ""

     def doc_to_text(self, doc):
-        print(doc)
         # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
         # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
         # appended onto the question, with no "Answer:" or even a newline. Do we *really*
@@ -41,6 +42,9 @@ class ANLIBase(HFTask):
         return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + '\nTrue, False, or Neither?'

     def doc_to_target(self, doc):
+        # True = entailment
+        # False = contradiction
+        # Neither = neutral
         return " " + ["True", "Neither", "False"][doc['label']]

     def construct_requests(self, doc, ctx):
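To make the prompt format concrete, here is what `doc_to_text` and `doc_to_target` produce for a hypothetical ANLI document (field values invented for illustration; the label encoding follows the comments added above, and `anli_r1` is the instance from the earlier sketch):

    doc = {
        "premise": "The cat sat on the mat.",
        "hypothesis": "An animal is on the mat.",
        "label": 0,  # 0 = entailment, 1 = neutral, 2 = contradiction
    }

    anli_r1.doc_to_text(doc)
    # 'The cat sat on the mat.\nQuestion: An animal is on the mat.\nTrue, False, or Neither?'
    anli_r1.doc_to_target(doc)
    # ' True'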
@@ -54,8 +58,10 @@ class ANLIBase(HFTask):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_true, _ = rf.loglikelihood(ctx, " True")
+        ll_neither, _ = rf.loglikelihood(ctx, " Neither")
+        ll_false, _ = rf.loglikelihood(ctx, " False")
+        return ll_true, ll_neither, ll_false

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -67,8 +73,11 @@ class ANLIBase(HFTask):
         :param results:
             The results of the requests created in construct_requests.
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        gold = doc["label"]
+        pred = np.argmax(results)
+        return {
+            "acc": pred == gold
+        }

     def aggregation(self):
         """
@@ -76,8 +85,9 @@ class ANLIBase(HFTask):
             A dictionary where keys are the names of submetrics and values are
             functions that aggregate a list of metrics
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean
+        }

     def higher_is_better(self):
         """
@@ -85,8 +95,9 @@ class ANLIBase(HFTask):
             A dictionary where keys are the names of submetrics and values are
             whether a higher value of the submetric is better
         """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": True
+        }

 class ANLIRound1(ANLIBase):
     SPLIT = 1
@@ -95,4 +106,4 @@ class ANLIRound2(ANLIBase):
     SPLIT = 2

 class ANLIRound3(ANLIBase):
-    SPLIT = 3
\ No newline at end of file
+    SPLIT = 3
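The three subclasses differ only in `SPLIT`, which presumably selects the round-suffixed splits of the Hugging Face `anli` dataset (`train_r1`, `dev_r1`, `test_r1`, and so on); the loading logic lives in `ANLIBase`/`HFTask`, outside this diff. A hypothetical sketch of that mapping:

    # Hypothetical: how a round number might map to HF `anli` split names
    # (the real mapping is in code not shown in this diff).
    round_num = 3
    splits = {k: f"{k}_r{round_num}" for k in ("train", "dev", "test")}
    # -> {'train': 'train_r3', 'dev': 'dev_r3', 'test': 'test_r3'}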