gaoqiong / lm-evaluation-harness · Commits

Commit 34f591af
Authored Apr 25, 2022 by jon-tow
Add multiple tasks
Parent: 2bfa4518
Showing 6 changed files with 81 additions and 387 deletions.
  lm_eval/tasks/__init__.py    +2    -0
  lm_eval/tasks/anli.py        +2    -47
  lm_eval/tasks/coqa.py        +22   -12
  lm_eval/tasks/drop.py        +30   -57
  lm_eval/tasks/glue.py        +15   -95
  lm_eval/tasks/superglue.py   +10   -176
lm_eval/tasks/__init__.py

@@ -52,6 +52,7 @@ from . import blimp
 from . import asdiv
 from . import gsm8k
 from . import storycloze
+from . import e2e_nlg_cleaned

 ########################################
 # Translation tasks
@@ -124,6 +125,7 @@ TASK_REGISTRY = {
     # Science related
     "pubmedqa": pubmedqa.Pubmed_QA,
     "sciq": sciq.SciQ,
+    "e2e_nlg_cleaned": e2e_nlg_cleaned.E2E_NLG_Cleaned,
     "qasper": qasper.QASPER,
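The registry entries above are plain dict items mapping a task-name string to its Task class. A minimal usage sketch of resolving one of the names added here; the no-argument construction is an assumption, not the harness's actual loader:

# Hypothetical lookup sketch; TASK_REGISTRY is the dict extended in the hunk above.
from lm_eval.tasks import TASK_REGISTRY

task_class = TASK_REGISTRY["e2e_nlg_cleaned"]  # -> e2e_nlg_cleaned.E2E_NLG_Cleaned
task = task_class()                            # constructor arguments assumed to default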
lm_eval/tasks/anli.py

@@ -10,7 +10,7 @@ provided explanations.
 Homepage: "https://github.com/facebookresearch/anli"
 """
 import numpy as np
-from lm_eval.base import rf, Task
+from lm_eval.base import rf, PromptSourceTask
 from lm_eval.metrics import mean
@@ -30,7 +30,7 @@ _CITATION = """
 """

-class ANLIBase(Task):
+class ANLIBase(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "anli"
     DATASET_NAME = None
@@ -59,51 +59,6 @@ class ANLIBase(Task):
         if self.has_test_docs():
             return self.dataset["test_r" + str(self.SPLIT)]

-    def doc_to_text(self, doc):
-        # OA does this a bit weirdly: they prepend "anli 1: anli 1: " to the beginning
-        # of the prompt (yes, repeating it!). also, " True, False, or Neither?" is directly
-        # appended onto the question, with no "Answer:" or even a newline. Do we *really*
-        # want to do it exactly as OA did?
-        return doc['premise'] + '\nQuestion: ' + doc['hypothesis'] + ' True, False, or Neither?\nAnswer:'
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = contradiction
-        # Neither = neutral
-        return " " + ["True", "Neither", "False"][doc['label']]
-
-    def construct_requests(self, doc, ctx):
-        """Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        ll_true, _ = rf.loglikelihood(ctx, " True")
-        ll_neither, _ = rf.loglikelihood(ctx, " Neither")
-        ll_false, _ = rf.loglikelihood(ctx, " False")
-        return ll_true, ll_neither, ll_false
-
-    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        gold = doc["label"]
-        pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
-
     def aggregation(self):
         """
         :returns: {str: [float] -> float}
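The removed ANLI methods implement the usual ranked-loglikelihood classification pattern: one loglikelihood request per label continuation, then an argmax in process_results. A minimal sketch of that scoring step, assuming the three log-likelihoods have already been computed; the label order is the one from the removed doc_to_target:

import numpy as np

ANLI_LABELS = ["True", "Neither", "False"]  # entailment, neutral, contradiction

def score_anli(loglikelihoods, gold_label_index):
    # Pick the label whose continuation the model found most likely.
    pred = int(np.argmax(loglikelihoods))
    return {"acc": pred == gold_label_index}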
lm_eval/tasks/coqa.py

@@ -67,6 +67,7 @@ class CoQA(PromptSourceTask):
     #         answers.append(additional_answer_for_turn)
     #     return answers

     @staticmethod
     def compute_scores(gold_list, pred):
         # tests for exact match and on the normalised answer (compute_exact)
@@ -90,19 +91,21 @@ class CoQA(PromptSourceTask):
             "f1": f1_sum / max(1, len(gold_list)),
         }

-    def construct_requests(self, doc, ctx):
-        """Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        cont_request = rf.greedy_until(ctx, ["\nQ:"])
-        return cont_request
+    def eos_token(self):
+        return "\n"
+
+    # def construct_requests(self, doc, ctx):
+    #     """Uses RequestFactory to construct Requests and returns an iterable of
+    #     Requests which will be sent to the LM.
+    #     :param doc:
+    #         The document as returned from training_docs, validation_docs, or test_docs.
+    #     :param ctx: str
+    #         The context string, generated by fewshot_context. This includes the natural
+    #         language description, as well as the few shot examples, and the question
+    #         part of the document for `doc`.
+    #     """
+    #     return cont_request

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -116,6 +119,13 @@ class CoQA(PromptSourceTask):
         """
         target = self.doc_to_target(doc).strip()
         pred = results[0].strip().split("\n")[0]
+        print("*" * 80)
+        print(f"DOC: {doc}")
+        # print(f"PS: {self.prompt.apply(doc)}")
+        print(f"TEXT: {self.doc_to_text(doc)}")
+        print(f"TARGET: {target} END TARGET")
+        print(pred)
+        print("*" * 80)

         # turn_id = len(doc["questions"]["input_text"])
         # gold_list = self.get_answers(doc, turn_id)
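With eos_token returning "\n", the generated answer is effectively cut at the first newline, which is exactly what the kept line pred = results[0].strip().split("\n")[0] does. A small sketch of that post-processing; the example generation string is illustrative only:

def first_line(generation: str) -> str:
    # Keep only the text up to the first newline, mirroring the "\n" eos_token.
    return generation.strip().split("\n")[0]

# Illustrative input; the trailing "\nQ: ..." is the kind of continuation the
# "\nQ:" stop sequence / newline eos_token is meant to cut off.
print(first_line("White House.\nQ: Who lives there?"))  # -> "White House."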
lm_eval/tasks/drop.py

@@ -39,7 +39,7 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
 class DROP(PromptSourceTask):
     VERSION = 1
-    DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
+    DATASET_PATH = "drop"  # inspect.getfile(lm_eval.datasets.drop.drop)
     DATASET_NAME = None

     def has_training_docs(self):
@@ -52,51 +52,13 @@ class DROP(PromptSourceTask):
         return False

     def training_docs(self):
-        if self._training_docs is None:
-            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
-        return self._training_docs
+        # if self._training_docs is None:
+        #     self._training_docs = list()
+        # return self._training_docs
+        return self.dataset["train"]

     def validation_docs(self):
-        return map(self._process_doc, self.dataset["validation"])
-
-    def _process_doc(self, doc):
-        return {
-            "id": doc["query_id"],
-            "passage": doc["passage"],
-            "question": doc["question"],
-            "answers": self.get_answers(doc),
-        }
-
-    @classmethod
-    def get_answers(cls, qa):
-        def _flatten_validated_answers(validated_answers):
-            """Flattens a dict of lists of validated answers.
-            {"number": ['1', '8'], ...}
-            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
-            """
-            vas = []
-            for i in range(len(validated_answers["number"])):
-                vas.append(
-                    {
-                        "number": validated_answers["number"][i],
-                        "date": validated_answers["date"][i],
-                        "spans": validated_answers["spans"][i],
-                    }
-                )
-            return vas
-
-        answers = []
-        answers_set = set()
-        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
-        for candidate in candidates:
-            answer = cls.parse_answer(candidate)
-            if answer in answers_set:
-                continue
-            answers_set.add(answer)
-            answers.append(answer)
-        return answers
+        return self.dataset["validation"]

     @classmethod
     def parse_answer(cls, answer):
@@ -117,19 +79,21 @@ class DROP(PromptSourceTask):
     # def doc_to_target(self, doc):
     #     return " " + ", ".join(doc["answers"][0])

-    def construct_requests(self, doc, ctx):
-        """Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        conts = [rf.greedy_until(ctx, ["."])]
-        return conts
+    # def construct_requests(self, doc, ctx):
+    #     """Uses RequestFactory to construct Requests and returns an iterable of
+    #     Requests which will be sent to the LM.
+    #     :param doc:
+    #         The document as returned from training_docs, validation_docs, or test_docs.
+    #     :param ctx: str
+    #         The context string, generated by fewshot_context. This includes the natural
+    #         language description, as well as the few shot examples, and the question
+    #         part of the document for `doc`.
+    #     """
+    #     conts = [rf.greedy_until(ctx, ["."])]
+    #     return conts
+
+    def eos_token(self):
+        return "."

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
@@ -145,6 +109,15 @@ class DROP(PromptSourceTask):
         pred = results[0].strip()
         target = self.doc_to_target(doc).strip()
+        print("*" * 80)
+        print(f"DOC: {doc}")
+        print(f"PS: {self.prompt.apply(doc)}")
+        print(f"TEXT: {self.doc_to_text(doc)}")
+        print(f"TARGET: {target} END TARGET")
+        print(pred)
+        print("*" * 80)
+
         preds = [pred]
         golds = [target]
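The removed get_answers helper flattens DROP's validated_answers (a dict of parallel lists) into one dict per answer before deduplicating. A compact sketch of that flattening step, kept here for reference since the promptsource-based task no longer pre-processes docs:

def flatten_validated_answers(validated_answers):
    # {"number": ["1", "8"], "date": [...], "spans": [...]}
    #   -> [{"number": "1", ...}, {"number": "8", ...}]
    keys = ("number", "date", "spans")
    return [
        {key: validated_answers[key][i] for key in keys}
        for i in range(len(validated_answers["number"]))
    ]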
lm_eval/tasks/glue.py

@@ -45,7 +45,7 @@ _CITATION = """
 # Single-Sentence Tasks

-class CoLA(Task):
+class CoLA(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "cola"
@@ -67,23 +67,20 @@ class CoLA(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "{}\nQuestion: Does this sentence make sense?\nAnswer:".format(doc["sentence"])
-
-    def doc_to_target(self, doc):
-        return " {}".format({1: "yes", 0: "no"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, " yes")
-        ll_false, _ = rf.loglikelihood(ctx, " no")
-        return ll_true, ll_false
-
-    def process_results(self, doc, results):
-        ll_true, ll_false = results
-        pred = ll_true > ll_false
-        gold = doc["label"]
-        return {
-            "mcc": (gold, pred)
-        }
+    def process_results(self, doc, results):
+        answer_choices_list = self.prompt.get_answer_choices_list(doc)
+        pred = np.argmax(results)
+        target = answer_choices_list.index(self.doc_to_target(doc).strip())
+        print("*" * 80)
+        print(f"DOC: {doc}")
+        print(f"TEXT: {self.doc_to_text(doc)}")
+        print(f"STRING TARGET: {self.doc_to_target(doc)} END TARGET")
+        print(f"TARGET: {target} END TARGET")
+        print(f"PRED: {pred}")
+        print("*" * 80)
+        return {
+            "mcc": (target, pred)
+        }

     def higher_is_better(self):
@@ -97,7 +94,7 @@ class CoLA(Task):
 }

-class SST(Task):
+class SST(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "sst2"
@@ -119,27 +116,6 @@ class SST(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "{}\nQuestion: Is this sentence positive or negative?\nAnswer:".format(
-            general_detokenize(doc["sentence"]),
-        )
-
-    def doc_to_target(self, doc):
-        return " {}".format({1: "positive", 0: "negative"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_positive, _ = rf.loglikelihood(ctx, " positive")
-        ll_negative, _ = rf.loglikelihood(ctx, " negative")
-        return ll_positive, ll_negative
-
-    def process_results(self, doc, results):
-        ll_positive, ll_negative = results
-        pred = ll_positive > ll_negative
-        gold = doc["label"]
-        return {
-            "acc": pred == gold
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -154,7 +130,7 @@ class SST(Task):
 # Inference Tasks

-class MNLI(Task):
+class MNLI(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "glue"
     DATASET_NAME = "mnli"
@@ -181,24 +157,6 @@ class MNLI(Task):
         if self.has_test_docs():
             return self.dataset["test_matched"]

-    def doc_to_text(self, doc):
-        return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
-            doc["premise"],
-            doc["hypothesis"].strip() + ('' if doc["hypothesis"].strip().endswith('.') else '.'),
-        )
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = contradiction
-        # Neither = neutral
-        return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, " True")
-        ll_neither, _ = rf.loglikelihood(ctx, " Neither")
-        ll_false, _ = rf.loglikelihood(ctx, " False")
-        return ll_true, ll_neither, ll_false
-
     def process_results(self, doc, results):
         gold = doc["label"]
         pred = np.argmax(results)
@@ -251,22 +209,6 @@ class QNLI(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "{}\n{}\nQuestion: Does this response answer the question?\nAnswer:".format(
-            doc["question"],
-            doc["sentence"],
-        )
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = not entailment
-        return " {}".format({0: "yes", 1: "no"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, " yes")
-        ll_no, _ = rf.loglikelihood(ctx, " no")
-        return ll_yes, ll_no
-
     def process_results(self, doc, results):
         ll_yes, ll_no = results
         pred = ll_no > ll_yes
@@ -342,14 +284,6 @@ class RTE(PromptSourceTask):
     def validation_docs(self):
         return self.dataset["validation"]

-    # def process_results(self, doc, results):
-    #     ll_true, ll_false = results
-    #     pred = ll_false > ll_true
-    #     gold = doc["label"]
-    #     return {
-    #         "acc": pred == gold
-    #     }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -386,20 +320,6 @@ class MRPC(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "Sentence 1: {}\nSentence 2: {}\nQuestion: Do both sentences mean the same thing?\nAnswer:".format(
-            general_detokenize(doc["sentence1"]),
-            general_detokenize(doc["sentence2"]),
-        )
-
-    def doc_to_target(self, doc):
-        return " {}".format(yesno(doc["label"]))
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, " yes")
-        ll_no, _ = rf.loglikelihood(ctx, " no")
-        return ll_yes, ll_no
-
     def process_results(self, doc, results):
         ll_yes, ll_no = results
         gold = doc["label"]
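CoLA's new process_results returns a (target, pred) index pair under "mcc" rather than a final score, so the reduction to a Matthews correlation coefficient has to happen at aggregation time. A hedged sketch of such an aggregation, assuming it uses sklearn (which superglue.py in this diff already imports); the function name is illustrative:

from sklearn.metrics import matthews_corrcoef

def aggregate_mcc(pairs):
    # pairs: iterable of (target, pred) index tuples collected per document.
    golds, preds = zip(*pairs)
    return matthews_corrcoef(golds, preds)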
lm_eval/tasks/superglue.py

@@ -12,7 +12,7 @@ TODO: WSC requires free-form generation.
 import numpy as np
 import sklearn
 import transformers.data.metrics.squad_metrics as squad_metrics
-from lm_eval.base import rf, Task
+from lm_eval.base import rf, PromptSourceTask
 from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno
 from lm_eval.utils import general_detokenize
@@ -32,7 +32,7 @@ _CITATION = """
 """

-class BoolQ(Task):
+class BoolQ(PromptSourceTask):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "boolq"
@@ -54,29 +54,6 @@ class BoolQ(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
-
-    def doc_to_target(self, doc):
-        return " " + yesno(doc['label'])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-
-        return {
-            "acc": acc
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -88,7 +65,7 @@ class BoolQ(Task):
 }

-class CommitmentBank(Task):
+class CommitmentBank(PromptSourceTask):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "cb"
@@ -110,25 +87,6 @@ class CommitmentBank(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
-            doc["premise"],
-            doc["hypothesis"],
-        )
-
-    def doc_to_target(self, doc):
-        # True = entailment
-        # False = contradiction
-        # Neither = neutral
-        return " {}".format({0: "True", 1: "False", 2: "Neither"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_true, _ = rf.loglikelihood(ctx, ' True')
-        ll_false, _ = rf.loglikelihood(ctx, ' False')
-        ll_neither, _ = rf.loglikelihood(ctx, ' Neither')
-        return ll_true, ll_false, ll_neither
-
     def process_results(self, doc, results):
         gold = doc["label"]
         pred = np.argmax(results)
@@ -163,7 +121,7 @@ class CommitmentBank(Task):
 }

-class Copa(Task):
+class Copa(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "copa"
@@ -185,28 +143,6 @@ class Copa(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        # Drop the period
-        connector = {
-            "cause": "because",
-            "effect": "therefore",
-        }[doc["question"]]
-        return doc["premise"].strip()[:-1] + f" {connector}"
-
-    def doc_to_target(self, doc):
-        correct_choice = doc["choice1"] if doc["label"] == 0 else doc["choice2"]
-        # Connect the sentences
-        return " " + self.convert_choice(correct_choice)
-
-    def construct_requests(self, doc, ctx):
-        choice1 = " " + self.convert_choice(doc["choice1"])
-        choice2 = " " + self.convert_choice(doc["choice2"])
-        ll_choice1, _ = rf.loglikelihood(ctx, choice1)
-        ll_choice2, _ = rf.loglikelihood(ctx, choice2)
-        return ll_choice1, ll_choice2
-
     def process_results(self, doc, results):
         gold = doc["label"]
         pred = np.argmax(results)
@@ -231,7 +167,7 @@ class Copa(Task):
         return choice[0].lower() + choice[1:]

-class MultiRC(Task):
+class MultiRC(PromptSourceTask):
     VERSION = 1
     DATASET_PATH = "super_glue"
     DATASET_NAME = "multirc"
@@ -253,26 +189,6 @@ class MultiRC(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
-
-    def doc_to_target(self, doc):
-        return " " + self.format_answer(answer=doc["answer"], label=doc["label"])
-
-    @staticmethod
-    def format_answer(answer, label):
-        label_str = "yes" if label else "no"
-        return f"{answer}\nIs the answer correct? {label_str}"
-
-    def construct_requests(self, doc, ctx):
-        true_choice = self.format_answer(answer=doc["answer"], label=True)
-        false_choice = self.format_answer(answer=doc["answer"], label=False)
-
-        ll_true_choice, _ = rf.loglikelihood(ctx, f' {true_choice}')
-        ll_false_choice, _ = rf.loglikelihood(ctx, f' {false_choice}')
-
-        return ll_true_choice, ll_false_choice
-
     def process_results(self, doc, results):
         ll_true_choice, ll_false_choice = results
         pred = ll_true_choice > ll_false_choice
@@ -291,7 +207,7 @@ class MultiRC(Task):
 }

-class ReCoRD(Task):
+class ReCoRD(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "record"
@@ -328,33 +244,13 @@
             "answers": sorted(list(set(doc["answers"]))),
         }

-    def doc_to_text(self, doc):
-        initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
-        text = initial_text + "\n\n"
-        for highlight in highlights:
-            text += f"  - {highlight}.\n"
-        return text
-
-    @classmethod
-    def format_answer(cls, query, entity):
-        return f' - {query}'.replace("@placeholder", entity)
-
-    def doc_to_target(self, doc):
-        # We only output the first correct entity in a doc
-        return self.format_answer(query=doc["query"], entity=doc["answers"][0])
-
-    def construct_requests(self, doc, ctx):
-        requests = [
-            rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity))
-            for entity in doc["entities"]
-        ]
-        return requests
-
     def process_results(self, doc, results):
         # ReCoRD's evaluation is actually deceptively simple:
         # - Pick the maximum likelihood prediction entity
         # - Evaluate the accuracy and token F1 PER EXAMPLE
         # - Average over all examples
+        # TODO (jon-tow): Look at result
         max_idx = np.argmax(np.array([result[0] for result in results]))
         prediction = doc["entities"][max_idx]
@@ -380,7 +276,7 @@ class ReCoRD(Task):
 }

-class WordsInContext(Task):
+class WordsInContext(PromptSourceTask):
     VERSION = 0
     DATASET_PATH = "super_glue"
     DATASET_NAME = "wic"
@@ -402,33 +298,6 @@ class WordsInContext(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
-               " two sentences above?\nAnswer:".format(
-            doc["sentence1"],
-            doc["sentence2"],
-            doc["sentence1"][doc["start1"]:doc["end1"]],
-        )
-
-    def doc_to_target(self, doc):
-        return " {}".format({0: "no", 1: "yes"}[doc["label"]])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-
-        return {
-            "acc": acc
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
@@ -440,7 +309,7 @@ class WordsInContext(Task):
 }

-class SGWinogradSchemaChallenge(Task):
+class SGWinogradSchemaChallenge(PromptSourceTask):
     VERSION = 0
     # Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
     # binary version of the task.
@@ -470,41 +339,6 @@ class SGWinogradSchemaChallenge(Task):
     def validation_docs(self):
         return self.dataset["validation"]

-    def doc_to_text(self, doc):
-        raw_passage = doc["text"]
-        # NOTE: HuggingFace span indices are word-based not character-based.
-        pre = " ".join(raw_passage.split()[:doc["span2_index"]])
-        post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:]
-        passage = general_detokenize(pre + " *{}*".format(doc['span2_text']) + post)
-        noun = doc["span1_text"]
-        pronoun = doc["span2_text"]
-        text = (
-            f"Passage: {passage}\n"
-            + f"Question: In the passage above, does the pronoun \"*{pronoun}*\" refer to \"*{noun}*\"?\n"
-            + "Answer:"
-        )
-        return text
-
-    def doc_to_target(self, doc):
-        return " " + yesno(doc['label'])
-
-    def construct_requests(self, doc, ctx):
-        ll_yes, _ = rf.loglikelihood(ctx, ' yes')
-        ll_no, _ = rf.loglikelihood(ctx, ' no')
-        return ll_yes, ll_no
-
-    def process_results(self, doc, results):
-        ll_yes, ll_no = results
-        gold = doc["label"]
-
-        acc = 1. if (ll_yes > ll_no) == gold else 0.
-
-        return {
-            "acc": acc
-        }
-
     def higher_is_better(self):
         return {
             "acc": True
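ReCoRD keeps its scoring step: one loglikelihood result per candidate entity, with the highest-scoring entity taken as the prediction. A minimal sketch of that selection, assuming results holds (loglikelihood, is_greedy) pairs in the same order as doc["entities"], as in the kept lines above:

import numpy as np

def pick_entity(entities, results):
    # results[i][0] is the log-likelihood of entity i's filled-in query.
    max_idx = np.argmax(np.array([result[0] for result in results]))
    return entities[max_idx]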