gaoqiong / lm-evaluation-harness · Commits

Commit 224b0854
authored Jun 08, 2023 by lintangsutawika
change import origin

parent 8171906d
Showing 1 changed file with 22 additions and 9 deletions

lm_eval/tasks/triviaqa.py  +22  -9
@@ -10,11 +10,12 @@ high quality distant supervision for answering the questions.
 Homepage: https://nlp.cs.washington.edu/triviaqa/
 """
 import inspect
 # import lm_eval.datasets.triviaqa.triviaqa
 import string
 
 from lm_eval.api.task import Task
 from lm_eval.api.instance import Instance
-from lm_eval.api.register import register_task
+from lm_eval.api.registry import register_task
 from lm_eval.api.metrics import mean
 
 _CITATION = """
@@ -29,10 +30,11 @@ _CITATION = """
 }
 """
 
 @register_task("triviaqa")
 class TriviaQA(Task):
     VERSION = 1
     DATASET_PATH = "trivia_qa"  # inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
     DATASET_NAME = "unfiltered.nocontext"
     OUTPUT_TYPE = "greedy_until"
@@ -90,18 +92,29 @@ class TriviaQA(Task):
         continuation = Instance(
             request_type=self.OUTPUT_TYPE,
             doc=doc,
-            arguments=(ctx, {
-                "until": ["\n", ".", ","],
-                "do_sample": False,
-            }),
+            arguments=(
+                ctx,
+                {
+                    "until": ["\n", ".", ","],
+                    "do_sample": False,
+                },
+            ),
             idx=0,
             **kwargs,
         )
         return continuation
 
     def process_results(self, doc, results):
-        continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
-        list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
+        continuation = (
+            results[0]
+            .strip()
+            .lower()
+            .translate(str.maketrans("", "", string.punctuation))
+        )
+        list_of_candidates = [
+            alias.lower().translate(str.maketrans("", "", string.punctuation))
+            for alias in self._remove_prefixes(doc["answer"]["aliases"])
+        ]
         return {"em": float(continuation in list_of_candidates)}
 
     def aggregation(self):
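
The exact-match logic in process_results is unchanged by this reformatting: the model's continuation is stripped, lowercased, and has its punctuation removed via str.maketrans, each gold alias is lowercased and de-punctuated the same way, and "em" is 1.0 exactly when the normalized continuation appears among the normalized aliases. A standalone illustration of that normalization, using made-up example strings:

import string


def normalize(text):
    # Strip whitespace, lowercase, and delete all ASCII punctuation.
    return text.strip().lower().translate(str.maketrans("", "", string.punctuation))


prediction = " Mount Everest.\n"  # hypothetical model output
aliases = ["Mt. Everest", "Mount Everest", "Everest"]  # hypothetical gold aliases
candidates = [normalize(alias) for alias in aliases]
em = float(normalize(prediction) in candidates)  # 1.0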
@@ -110,4 +123,4 @@ class TriviaQA(Task):
         }
 
     def higher_is_better(self):
         return {"em": True}
\ No newline at end of file