Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1297e342
Commit
1297e342
authored
Feb 12, 2021
by
&
Browse files
add tasks to registry
parent
12ba8426
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
18 deletions
+45
-18
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+4
-0
lm_eval/tasks/translation.py
lm_eval/tasks/translation.py
+41
-18
No files found.
lm_eval/tasks/__init__.py
View file @
1297e342
...
@@ -21,6 +21,7 @@ from . import pubmedqa
...
@@ -21,6 +21,7 @@ from . import pubmedqa
from
.
import
sciq
from
.
import
sciq
from
.
import
webqs
from
.
import
webqs
from
.
import
qa4mre
from
.
import
qa4mre
from
.
import
translation
TASK_REGISTRY
=
{
TASK_REGISTRY
=
{
...
@@ -85,6 +86,9 @@ TASK_REGISTRY = {
...
@@ -85,6 +86,9 @@ TASK_REGISTRY = {
"arithmetic_2dm"
:
arithmetic
.
Arithmetic2DMultiplication
,
"arithmetic_2dm"
:
arithmetic
.
Arithmetic2DMultiplication
,
"arithmetic_1dc"
:
arithmetic
.
Arithmetic1DComposite
,
"arithmetic_1dc"
:
arithmetic
.
Arithmetic1DComposite
,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
**
translation
.
create_tasks_from_benchmarks
(
translation
.
selected_benchmarks
)
}
}
...
...
lm_eval/tasks/translation.py
View file @
1297e342
...
@@ -26,26 +26,26 @@ sacrebleu_datasets = sacrebleu.DATASETS
...
@@ -26,26 +26,26 @@ sacrebleu_datasets = sacrebleu.DATASETS
# 6 total
# 6 total
gpt3_
test
s
=
{
gpt3_
benchmark
s
=
{
"wmt14"
:
[
'en-fr'
,
'fr-en'
],
# French
"wmt14"
:
[
'en-fr'
,
'fr-en'
],
# French
"wmt16"
:
[
'en-ro'
,
'ro-en'
,
'de-en'
,
'en-de'
],
# German, Romanian
"wmt16"
:
[
'en-ro'
,
'ro-en'
,
'de-en'
,
'en-de'
],
# German, Romanian
}
}
# 14 total
# 14 total
selected_
test
s
=
{
selected_
benchmark
s
=
{
**
gpt3_
test
s
,
**
gpt3_
benchmark
s
,
"wmt20"
:
[
'fr-de'
,
'de-fr'
,
'en-ru'
,
'ru-en'
,
'en-iu'
,
'iu-en'
],
# French, German, Russian, Inuit
"wmt20"
:
[
'fr-de'
,
'de-fr'
,
'en-ru'
,
'ru-en'
,
'en-iu'
,
'iu-en'
],
# French, German, Russian, Inuit
"iwslt17"
:
[
'en-ar'
,
'ar-en'
]
# Arabic
"iwslt17"
:
[
'en-ar'
,
'ar-en'
]
# Arabic
}
}
# 319 total
# 319 total
all_
test
s
=
{
all_
benchmark
s
=
{
ts
:
sacrebleu
.
get_langpairs_for_testset
(
ts
)
ts
:
sacrebleu
.
get_langpairs_for_testset
(
ts
)
for
ts
in
sacrebleu
.
get_available_testsets
()
for
ts
in
sacrebleu
.
get_available_testsets
()
}
}
available_tests
=
{
available_tests
=
{
"gpt3_tests"
:
gpt3_
test
s
,
"gpt3_tests"
:
gpt3_
benchmark
s
,
"selected_tests"
:
selected_
test
s
,
"selected_tests"
:
selected_
benchmark
s
,
"all_tests"
:
all_
test
s
"all_tests"
:
all_
benchmark
s
}
}
...
@@ -53,6 +53,14 @@ available_tests = {
...
@@ -53,6 +53,14 @@ available_tests = {
# Tasks
# Tasks
########################################
########################################
def
create_tasks_from_benchmarks
(
benchmark_dict
):
"""Creates a dictionary of tasks from a dict {dataset: [lang_pair, ...]}"""
return
{
f
"
{
dataset
}
-
{
language_pair
}
"
:
create_translation_task
(
dataset
,
language_pair
)
for
dataset
,
language_pairs
in
benchmark_dict
.
items
()
for
language_pair
in
language_pairs
}
def
create_translation_task
(
dataset
,
language_pair
):
def
create_translation_task
(
dataset
,
language_pair
):
class
TranslationTask
(
GeneralTranslationTask
):
class
TranslationTask
(
GeneralTranslationTask
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -125,10 +133,11 @@ class GeneralTranslationTask(Task):
...
@@ -125,10 +133,11 @@ class GeneralTranslationTask(Task):
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
# These metrics are corpus-level not sentence level, so we'll hide the
# These metrics are corpus-level not sentence level, so we'll hide the
# results in this dict and compute the corpus score in the aggregate method
# results in this dict and compute the corpus score in the aggregate method
ref_pred
=
(
doc
[
"ref"
],
results
)
return
{
return
{
"bleu"
:
(
doc
[
"ref"
],
results
)
,
"bleu"
:
ref_pred
,
"chrf"
:
(
doc
[
"ref"
],
results
)
,
"chrf"
:
ref_pred
,
"ter"
:
(
doc
[
"ref"
],
results
)
,
"ter"
:
ref_pred
,
}
}
def
aggregation
(
self
):
def
aggregation
(
self
):
...
@@ -157,7 +166,9 @@ class GeneralTranslationTask(Task):
...
@@ -157,7 +166,9 @@ class GeneralTranslationTask(Task):
def
fewshot_description
(
self
):
def
fewshot_description
(
self
):
language_codes
=
self
.
sacrebleu_language_pair
.
split
(
"-"
)
language_codes
=
self
.
sacrebleu_language_pair
.
split
(
"-"
)
return
f
"Translate
{
code_to_language
(
language_codes
[
0
])
}
to
{
language_codes
[
1
]
}
."
src_lang
=
code_to_language
(
language_codes
[
0
])
tar_lang
=
code_to_language
(
language_codes
[
1
])
return
f
"Translate these
{
src_lang
}
phrases to
{
tar_lang
}
."
# TODO This should be something like
# TODO This should be something like
# French: {src_line}
# French: {src_line}
...
@@ -165,6 +176,12 @@ class GeneralTranslationTask(Task):
...
@@ -165,6 +176,12 @@ class GeneralTranslationTask(Task):
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
provide_description
):
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
provide_description
):
return
""
return
""
def
__str__
(
self
):
language_codes
=
self
.
sacrebleu_language_pair
.
split
(
"-"
)
src_lang
=
code_to_language
(
language_codes
[
0
])
tar_lang
=
code_to_language
(
language_codes
[
1
])
return
f
"
{
self
.
sacrebleu_dataset
.
upper
()
}
{
src_lang
}
to
{
tar_lang
}
Task"
########################################
########################################
# Util
# Util
...
@@ -173,7 +190,7 @@ class GeneralTranslationTask(Task):
...
@@ -173,7 +190,7 @@ class GeneralTranslationTask(Task):
def
code_to_language
(
code
):
def
code_to_language
(
code
):
# key is alpha_2 or alpha_3 depending on the code length
# key is alpha_2 or alpha_3 depending on the code length
language_tuple
=
pycountry
.
languages
.
get
({
f
"alpha_
{
len
(
code
)
}
"
:
code
})
language_tuple
=
pycountry
.
languages
.
get
(
**
{
f
"alpha_
{
len
(
code
)
}
"
:
code
})
return
language_tuple
.
name
return
language_tuple
.
name
def
print_available_tests
():
def
print_available_tests
():
...
@@ -181,14 +198,20 @@ def print_available_tests():
...
@@ -181,14 +198,20 @@ def print_available_tests():
def
main
():
def
main
():
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests()
# print_available_tests()
# print(len(sacrebleu.print_test_set("wmt14", "fr-en", "src")))
# sacrebleu.print_test_set("wmt14", "fr-en", "src")
# print(GeneralTranslationTask("wmt14", "fr-en"))
print
(
sum
(
# # Print number of benchmarks
[
len
(
sacrebleu
.
get_langpairs_for_testset
(
ts
))
for
ts
in
sacrebleu
.
get_available_testsets
()])
# print(sum([
)
# len(sacrebleu.get_langpairs_for_testset(ts))
pass
# for ts in sacrebleu.get_available_testsets()
# ]))
# Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment