Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
2e43c8d7
Commit
2e43c8d7
authored
Jan 19, 2025
by
Baber
Browse files
add qa_hotpot
parent
c8a3f2e2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
5 deletions
+40
-5
lm_eval/tasks/ruler/qa_hotpot.yaml
lm_eval/tasks/ruler/qa_hotpot.yaml
+3
-0
lm_eval/tasks/ruler/qa_squad.yaml
lm_eval/tasks/ruler/qa_squad.yaml
+2
-2
lm_eval/tasks/ruler/qa_utils.py
lm_eval/tasks/ruler/qa_utils.py
+35
-3
No files found.
lm_eval/tasks/ruler/qa_hotpot.yaml
0 → 100644
View file @
2e43c8d7
include
:
qa_squad.yaml
task
:
ruler_qa_hotpot
download_dataset
:
!function
qa_utils.get_hotpotqa
lm_eval/tasks/ruler/qa.yaml
→
lm_eval/tasks/ruler/qa
_squad
.yaml
View file @
2e43c8d7
include
:
niah_1.yaml
include
:
niah_1.yaml
task
:
ruler_qa
task
:
ruler_qa
_squad
download_dataset
:
!function
qa_utils.get_
qa_dataset
download_dataset
:
!function
qa_utils.get_
squad
test_split
:
test
test_split
:
test
generation_kwargs
:
generation_kwargs
:
do_sample
:
false
do_sample
:
false
...
...
lm_eval/tasks/ruler/qa_utils.py
View file @
2e43c8d7
import
itertools
# noqa: I001
import
itertools
# noqa: I001
import
random
import
random
from
functools
import
cache
from
functools
import
cache
,
partial
import
datasets
import
datasets
import
requests
import
requests
...
@@ -27,6 +27,7 @@ def download_json(url):
...
@@ -27,6 +27,7 @@ def download_json(url):
return
data
return
data
@
cache
def
read_squad
(
url
=
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"
):
def
read_squad
(
url
=
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"
):
data
=
download_json
(
url
)
data
=
download_json
(
url
)
total_docs
=
[
p
[
"context"
]
for
d
in
data
[
"data"
]
for
p
in
d
[
"paragraphs"
]]
total_docs
=
[
p
[
"context"
]
for
d
in
data
[
"data"
]
for
p
in
d
[
"paragraphs"
]]
...
@@ -55,6 +56,30 @@ def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.
...
@@ -55,6 +56,30 @@ def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.
return
total_qas
,
total_docs
return
total_qas
,
total_docs
@
cache
def
read_hotpotqa
(
url
=
"http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json"
,
):
data
=
download_json
(
url
)
total_docs
=
[
f
"
{
t
}
\n
{
''
.
join
(
p
)
}
"
for
d
in
data
for
t
,
p
in
d
[
"context"
]]
total_docs
=
sorted
(
list
(
set
(
total_docs
)))
total_docs_dict
=
{
c
:
idx
for
idx
,
c
in
enumerate
(
total_docs
)}
total_qas
=
[]
for
d
in
data
:
total_qas
.
append
(
{
"query"
:
d
[
"question"
],
"outputs"
:
[
d
[
"answer"
]],
"context"
:
[
total_docs_dict
[
f
"
{
t
}
\n
{
''
.
join
(
p
)
}
"
]
for
t
,
p
in
d
[
"context"
]
],
}
)
return
total_qas
,
total_docs
def
generate_input_output
(
def
generate_input_output
(
index
:
int
,
num_docs
:
int
,
qas
:
list
[
dict
],
docs
:
list
[
str
]
index
:
int
,
num_docs
:
int
,
qas
:
list
[
dict
],
docs
:
list
[
str
]
)
->
tuple
[
str
,
list
[
str
]]:
)
->
tuple
[
str
,
list
[
str
]]:
...
@@ -176,10 +201,13 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
...
@@ -176,10 +201,13 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
return
write_jsons
return
write_jsons
def
get_qa_dataset
(
**
kwargs
):
def
get_qa_dataset
(
ds
,
**
kwargs
):
kwargs
=
kwargs
.
get
(
"metadata"
,
{})
kwargs
=
kwargs
.
get
(
"metadata"
,
{})
pretrained
=
kwargs
.
get
(
"tokenizer"
,
kwargs
.
get
(
"pretrained"
,
{}))
pretrained
=
kwargs
.
get
(
"tokenizer"
,
kwargs
.
get
(
"pretrained"
,
{}))
qas
,
docs
=
read_squad
()
if
ds
==
"squad"
:
qas
,
docs
=
read_squad
()
else
:
qas
,
docs
=
read_hotpotqa
()
df
=
(
df
=
(
get_dataset
(
pretrained
,
docs
=
docs
,
qas
=
qas
,
max_seq_length
=
seq
)
get_dataset
(
pretrained
,
docs
=
docs
,
qas
=
qas
,
max_seq_length
=
seq
)
for
seq
in
SEQ_LENGTHS
for
seq
in
SEQ_LENGTHS
...
@@ -190,3 +218,7 @@ def get_qa_dataset(**kwargs):
...
@@ -190,3 +218,7 @@ def get_qa_dataset(**kwargs):
list
(
itertools
.
chain
.
from_iterable
(
df
)),
split
=
datasets
.
Split
.
TEST
list
(
itertools
.
chain
.
from_iterable
(
df
)),
split
=
datasets
.
Split
.
TEST
)
)
}
}
get_squad
=
partial
(
get_qa_dataset
,
"squad"
)
get_hotpotqa
=
partial
(
get_qa_dataset
,
"hotpotqa"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment