gaoqiong / lm-evaluation-harness · Commit 44f03593 (unverified)

Authored Apr 10, 2021 by Leo Gao; committed by GitHub on Apr 10, 2021

Merge pull request #176 from EleutherAI/per_char_agg

Do per character loss aggregation for multiple choice tasks

Parents: fd26ef16, 1ebf41d3
Showing 2 changed files with 31 additions and 26 deletions:

  lm_eval/base.py        +8  -3
  lm_eval/tasks/piqa.py  +23 -23
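For context on the diffs below: scoring a multiple-choice document by the raw sum of loglikelihoods biases the argmax toward short completions, because longer completions accumulate more negative loglikelihood mass. This commit adds a length-normalized variant, acc_norm, which divides each choice's summed loglikelihood by its character count before taking the argmax. A minimal standalone sketch of the two metrics (the function name and wrapper are illustrative, not part of the repo):

    import numpy as np

    def score_choices(loglikelihoods, choices, gold):
        # loglikelihoods: one summed loglikelihood per answer choice
        # choices: the answer strings; gold: index of the correct choice
        results = np.array(loglikelihoods)
        acc = 1. if np.argmax(results) == gold else 0.
        # Divide by character count so each choice competes on its
        # average per-character loglikelihood rather than its total.
        completion_len = np.array([float(len(c)) for c in choices])
        acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.
        return acc, acc_norm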
lm_eval/base.py @ 44f03593

@@ -226,19 +226,24 @@ class MultipleChoiceTask(Task):
         gold = doc["gold"]

         acc = 1. if np.argmax(results) == gold else 0.
+        completion_len = np.array([float(len(i)) for i in doc["choices"]])
+        acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.

         return {
-            "acc": acc
+            "acc": acc,
+            "acc_norm": acc_norm,
         }

     def higher_is_better(self):
         return {
-            "acc": True
+            "acc": True,
+            "acc_norm": True,
         }

     def aggregation(self):
         return {
-            "acc": mean
+            "acc": mean,
+            "acc_norm": mean,
         }
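A worked example of why the two metrics can disagree (numbers invented for illustration): a short wrong answer can carry the higher total loglikelihood while a longer right answer has the better per-character average.

    import numpy as np

    results = np.array([-9.0, -14.0])        # summed loglikelihoods for choices 0 and 1
    completion_len = np.array([10.0, 20.0])  # characters in each choice

    print(np.argmax(results))                   # 0: raw sum favors the short choice
    print(np.argmax(results / completion_len))  # 1: -0.7 beats -0.9 per character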
lm_eval/tasks/piqa.py @ 44f03593

 import numpy as np
-from lm_eval.base import rf
+from lm_eval.base import MultipleChoiceTask, rf
 from ..metrics import mean
 from .common import HFTask


-class PiQA(HFTask):
+class PiQA(HFTask, MultipleChoiceTask):
     DATASET_PATH = "piqa"
     DATASET_NAME = None

@@ -21,29 +21,29 @@ class PiQA(HFTask):
         # TODO: figure out fewshot description
         return ""

-    def doc_to_text(self, doc):
-        return "Question: " + doc["goal"] + "\nAnswer:"
+    def _convert_standard(self, doc):
+        out_doc = {
+            "goal": doc["goal"],
+            "choices": [doc["sol1"], doc["sol2"]],
+            "gold": doc["label"],
+        }
+        return out_doc

-    def doc_to_target(self, doc):
-        solutions = [doc["sol1"], doc["sol2"]]
-        return " " + solutions[doc["label"]]
+    def _load_docs(self, docs):
+        for record in docs:
+            yield self._convert_standard(record)

-    def construct_requests(self, doc, ctx):
-        ll_1, _ = rf.loglikelihood(ctx, " " + doc['sol1'])
-        ll_2, _ = rf.loglikelihood(ctx, " " + doc['sol2'])
-        return ll_1, ll_2
+    def training_docs(self):
+        docs = super().training_docs()
+        return self._load_docs(docs)

-    def process_results(self, doc, results):
-        return {
-            'acc': np.argmax(results) == doc["label"]
-        }
+    def validation_docs(self):
+        docs = super().validation_docs()
+        return self._load_docs(docs)

-    def aggregation(self):
-        return {
-            'acc': mean
-        }
+    def test_docs(self):
+        docs = super().test_docs()
+        return self._load_docs(docs)

-    def higher_is_better(self):
-        return {
-            'acc': True
-        }
+    def doc_to_text(self, doc):
+        return "Question: " + doc["goal"] + "\nAnswer:"