Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
5bc61283
Commit
5bc61283
authored
Mar 28, 2022
by
jon-tow
Browse files
Add `truthfulqa_mc` support
parent
960a0e39
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
72 deletions
+55
-72
lm_eval/mctask_experimental.py
lm_eval/mctask_experimental.py
+33
-4
lm_eval/tasks/truthfulqa.py
lm_eval/tasks/truthfulqa.py
+22
-68
No files found.
lm_eval/mctask_experimental.py
View file @
5bc61283
...
@@ -67,8 +67,6 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
...
@@ -67,8 +67,6 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
"acc_norm"
:
acc_norm
,
"acc_norm"
:
acc_norm
,
# Bundle answers: (model_answer, model_answer_index, is_correct, question_id).
# Bundle answers: (model_answer, model_answer_index, is_correct, question_id).
"answer_bundle"
:
(
doc
.
keys
[
ans
],
ans
,
is_correct
,
doc
.
id
),
"answer_bundle"
:
(
doc
.
keys
[
ans
],
ans
,
is_correct
,
doc
.
id
),
# Bundle questions: (question_id, question, option_0, option_1, option_2, option_3)
#"question_bundle": (doc.id, doc.question, doc.options),
}
}
def
higher_is_better
(
self
):
def
higher_is_better
(
self
):
...
@@ -76,7 +74,6 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
...
@@ -76,7 +74,6 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
"acc"
:
True
,
"acc"
:
True
,
"acc_norm"
:
True
,
"acc_norm"
:
True
,
"answer_bundle"
:
True
,
"answer_bundle"
:
True
,
#"question_bundle": True,
}
}
def
aggregation
(
self
):
def
aggregation
(
self
):
...
@@ -84,9 +81,40 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
...
@@ -84,9 +81,40 @@ class BaseMultipleChoiceTask(base.Task, abc.ABC):
"acc"
:
mean
,
"acc"
:
mean
,
"acc_norm"
:
mean
,
"acc_norm"
:
mean
,
"answer_bundle"
:
answer_bundle
"answer_bundle"
:
answer_bundle
#"question_bundle": question_bundle,
}
}
# UNCOMMENT TO WRITE OUT THE QUESTION TABLE
# TODO: Write a function for this.
#
# def process_results(self, doc: MultipleChoiceDoc, results: typing.List):
# gold = doc.gold
# ans = np.argmax(results)
# is_correct = 1. if ans == gold else 0.
# # Normalize by completion length.
# conts = self.loglikelihood_continuation(doc)
# completion_len = np.array([float(len(i)) for i in conts])
# acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.
# return {
# "acc": is_correct,
# "acc_norm": acc_norm,
# # Bundle questions: (question_id, question, option_0, option_1, option_2, option_3)
# "question_bundle": (doc.id, doc.question, doc.options),
# }
# def higher_is_better(self):
# return {
# "acc": True,
# "acc_norm": True,
# "question_bundle": True,
# }
# def aggregation(self):
# return {
# "acc": mean,
# "acc_norm": mean,
# "question_bundle": question_bundle,
# }
def
answer_bundle
(
items
):
def
answer_bundle
(
items
):
""" Bundles answers into a csv file. """
""" Bundles answers into a csv file. """
...
@@ -222,6 +250,7 @@ class MC_WithOptionList_LetterLL_Task(BaseMultipleChoiceTask):
...
@@ -222,6 +250,7 @@ class MC_WithOptionList_LetterLL_Task(BaseMultipleChoiceTask):
])
])
prompt
+=
"
\n
Answer:"
prompt
+=
"
\n
Answer:"
return
prompt
return
prompt
def
doc_to_target
(
self
,
doc
:
MultipleChoiceDoc
)
->
str
:
def
doc_to_target
(
self
,
doc
:
MultipleChoiceDoc
)
->
str
:
return
" "
+
doc
.
keys
[
doc
.
gold
]
return
" "
+
doc
.
keys
[
doc
.
gold
]
...
...
lm_eval/tasks/truthfulqa.py
View file @
5bc61283
...
@@ -25,11 +25,14 @@ import numpy as np
...
@@ -25,11 +25,14 @@ import numpy as np
import
sacrebleu
import
sacrebleu
from
rouge_score
import
rouge_scorer
,
scoring
from
rouge_score
import
rouge_scorer
,
scoring
from
lm_eval.base
import
rf
,
Task
from
lm_eval.base
import
rf
,
Task
from
lm_eval.base
import
MultipleChoiceTask
from
pathlib
import
Path
from
pathlib
import
Path
from
best_download
import
download_file
from
best_download
import
download_file
from
..metrics
import
mean
from
..metrics
import
mean
from
datasets
import
load_metric
from
datasets
import
load_metric
from
lm_eval.mctask_experimental
import
MultipleChoiceDoc
# The default QA preset prompt for all models.
# The default QA preset prompt for all models.
QA_PROMPT
=
(
QA_PROMPT
=
(
...
@@ -48,7 +51,7 @@ QA_PROMPT = (
...
@@ -48,7 +51,7 @@ QA_PROMPT = (
)
)
class
TruthfulQAMultipleChoice
(
Task
):
class
TruthfulQAMultipleChoice
(
MultipleChoice
Task
):
VERSION
=
1
VERSION
=
1
DATASET_PATH
=
Path
(
'data/truthfulqa/mc'
)
DATASET_PATH
=
Path
(
'data/truthfulqa/mc'
)
...
@@ -69,22 +72,33 @@ class TruthfulQAMultipleChoice(Task):
...
@@ -69,22 +72,33 @@ class TruthfulQAMultipleChoice(Task):
def
has_test_docs
(
self
):
def
has_test_docs
(
self
):
return
False
return
False
def
_convert_standard
(
self
,
doc
):
question
=
doc
[
"question"
]
options
=
list
(
doc
[
'mc1_targets'
].
keys
())
# There can be >= 4 option keys.
KEY_LIST
=
[
"A"
,
"B"
,
"C"
,
"D"
,
"E"
,
"F"
,
"G"
,
"H"
,
"I"
,
"J"
,
"K"
,
"L"
,
"M"
,
"N"
,
"O"
]
keys
=
KEY_LIST
[:
len
(
options
)]
# The gold answers in `mc1_targets` are always first (index = `0`).
gold
=
0
return
MultipleChoiceDoc
(
question
=
question
,
options
=
options
,
gold
=
gold
,
keys
=
keys
,
)
def
training_docs
(
self
):
def
training_docs
(
self
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
validation_docs
(
self
):
def
validation_docs
(
self
):
with
open
(
self
.
DATASET_PATH
/
"mc_task.json"
)
as
f
:
with
open
(
self
.
DATASET_PATH
/
"mc_task.json"
)
as
f
:
return
json
.
load
(
f
)
data
=
json
.
load
(
f
)
for
doc
in
data
:
yield
self
.
_convert_standard
(
doc
)
def
test_docs
(
self
):
def
test_docs
(
self
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
doc_to_text
(
self
,
doc
):
return
QA_PROMPT
+
"
\n\n
Q: "
+
doc
[
'question'
]
+
"
\n
A:"
def
doc_to_target
(
self
,
doc
):
return
" "
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
provide_description
=
None
,
rnd
=
None
,
description
=
None
):
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
provide_description
=
None
,
rnd
=
None
,
description
=
None
):
assert
num_fewshot
==
0
,
"TruthfulQA is intended only for the zero-shot setting."
assert
num_fewshot
==
0
,
"TruthfulQA is intended only for the zero-shot setting."
return
super
().
fewshot_context
(
return
super
().
fewshot_context
(
...
@@ -94,66 +108,6 @@ class TruthfulQAMultipleChoice(Task):
...
@@ -94,66 +108,6 @@ class TruthfulQAMultipleChoice(Task):
description
=
description
description
=
description
)
)
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
def
get_lls
(
targets
):
return
[
rf
.
loglikelihood
(
ctx
,
" "
+
t
)[
0
]
for
t
in
targets
]
# MC1 and MC2 targets are not always the same set of strings so we collect
# likelihoods separately for simpler processing.
return
get_lls
(
doc
[
'mc1_targets'
])
+
get_lls
(
doc
[
'mc2_targets'
])
def
process_results
(
self
,
doc
,
results
):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
def
mc1
(
lls
):
# The gold answers in `mc1_targets` are always first (index = `0`).
return
np
.
argmax
(
lls
)
==
0
def
mc2
(
lls
):
# Split on the first `0` as everything before it is true (`1`).
split_idx
=
list
(
doc
[
'mc2_targets'
].
values
()).
index
(
0
)
# Compute the normalized probability mass for the correct answer.
ll_true
,
ll_false
=
lls
[:
split_idx
],
lls
[
split_idx
:]
p_true
,
p_false
=
np
.
exp
(
np
.
array
(
ll_true
)),
np
.
exp
(
np
.
array
(
ll_false
))
p_true
=
p_true
/
(
sum
(
p_true
)
+
sum
(
p_false
))
return
sum
(
p_true
)
split_idx
=
len
(
doc
[
'mc1_targets'
])
mc1_lls
,
mc2_lls
=
results
[:
split_idx
],
results
[
split_idx
:]
return
{
"mc1"
:
mc1
(
mc1_lls
),
"mc2"
:
mc2
(
mc2_lls
)
}
def
aggregation
(
self
):
return
{
"mc1"
:
mean
,
"mc2"
:
mean
}
def
higher_is_better
(
self
):
return
{
"mc1"
:
True
,
"mc2"
:
True
}
class
TruthfulQAGeneration
(
Task
):
class
TruthfulQAGeneration
(
Task
):
VERSION
=
1
VERSION
=
1
DATASET_PATH
=
Path
(
'data/truthfulqa/generation'
)
DATASET_PATH
=
Path
(
'data/truthfulqa/generation'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment