Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1e5194a2
Unverified
Commit
1e5194a2
authored
Feb 14, 2021
by
Leo Gao
Committed by
GitHub
Feb 14, 2021
Browse files
Merge pull request #153 from EleutherAI/translation-v2
Translation v2
parents
758b9e3c
d0a301cc
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
38 deletions
+57
-38
lm_eval/metrics.py
lm_eval/metrics.py
+9
-4
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+35
-2
lm_eval/tasks/translation.py
lm_eval/tasks/translation.py
+13
-32
No files found.
lm_eval/metrics.py
View file @
1e5194a2
import
math
import
math
from
collections
import
Iterable
from
pprint
import
pprint
from
pprint
import
pprint
import
numpy
as
np
import
numpy
as
np
...
@@ -107,6 +108,10 @@ def ter(items):
...
@@ -107,6 +108,10 @@ def ter(items):
return
sacrebleu
.
corpus_ter
(
preds
,
refs
).
score
return
sacrebleu
.
corpus_ter
(
preds
,
refs
).
score
def
is_non_str_iterable
(
obj
):
return
isinstance
(
obj
,
Iterable
)
and
not
isinstance
(
obj
,
str
)
def
_sacreformat
(
refs
,
preds
):
def
_sacreformat
(
refs
,
preds
):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str])
# Sacrebleu expects (List[str], List[List[str])
...
@@ -118,17 +123,17 @@ def _sacreformat(refs, preds):
...
@@ -118,17 +123,17 @@ def _sacreformat(refs, preds):
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if
not
is
instance
(
refs
,
list
):
if
not
is
_non_str_iterable
(
refs
):
refs
=
list
(
refs
)
refs
=
list
(
refs
)
if
not
is
instance
(
refs
[
0
],
list
):
if
not
is
_non_str_iterable
(
refs
):
refs
=
[[
ref
]
for
ref
in
refs
]
refs
=
[[
ref
]
for
ref
in
refs
]
refs
=
list
(
zip
(
*
refs
))
refs
=
list
(
zip
(
*
refs
))
# Note the number of refs in each ref list much match the number of preds
# Note the number of refs in each ref list much match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if
not
is
instanc
e
(
preds
,
list
):
if
not
is
_non_str_iterabl
e
(
preds
):
preds
=
list
(
preds
)
preds
=
list
(
preds
)
if
is
instanc
e
(
preds
[
0
]
,
list
):
if
is
_non_str_iterabl
e
(
preds
[
0
]):
assert
len
(
preds
[
0
])
==
1
,
f
"Pred must be a str, was
{
preds
[
0
]
}
"
assert
len
(
preds
[
0
])
==
1
,
f
"Pred must be a str, was
{
preds
[
0
]
}
"
preds
=
[
pred
[
0
]
for
pred
in
preds
]
preds
=
[
pred
[
0
]
for
pred
in
preds
]
...
...
lm_eval/tasks/__init__.py
View file @
1e5194a2
from
pprint
import
pprint
from
pprint
import
pprint
import
sacrebleu
from
.
import
superglue
from
.
import
superglue
from
.
import
glue
from
.
import
glue
from
.
import
arc
from
.
import
arc
...
@@ -27,6 +29,36 @@ from . import translation
...
@@ -27,6 +29,36 @@ from . import translation
from
.
import
headqa
from
.
import
headqa
from
.
import
mathqa
from
.
import
mathqa
########################################
# Translation tasks
########################################
# 6 total
gpt3_translation_benchmarks
=
{
"wmt14"
:
[
'en-fr'
,
'fr-en'
],
# French
"wmt16"
:
[
'en-ro'
,
'ro-en'
,
'de-en'
,
'en-de'
],
# German, Romanian
}
# 28 total
selected_translation_benchmarks
=
{
**
gpt3_translation_benchmarks
,
"wmt20"
:
sacrebleu
.
get_langpairs_for_testset
(
"wmt20"
),
"iwslt17"
:
[
'en-ar'
,
'ar-en'
]
# Arabic
}
# 319 total
all_translation_benchmarks
=
{
ts
:
sacrebleu
.
get_langpairs_for_testset
(
ts
)
for
ts
in
sacrebleu
.
get_available_testsets
()
}
########################################
# All tasks
########################################
TASK_REGISTRY
=
{
TASK_REGISTRY
=
{
# GLUE
# GLUE
"cola"
:
glue
.
CoLA
,
"cola"
:
glue
.
CoLA
,
...
@@ -90,12 +122,13 @@ TASK_REGISTRY = {
...
@@ -90,12 +122,13 @@ TASK_REGISTRY = {
"arithmetic_5ds"
:
arithmetic
.
Arithmetic5DMinus
,
"arithmetic_5ds"
:
arithmetic
.
Arithmetic5DMinus
,
"arithmetic_2dm"
:
arithmetic
.
Arithmetic2DMultiplication
,
"arithmetic_2dm"
:
arithmetic
.
Arithmetic2DMultiplication
,
"arithmetic_1dc"
:
arithmetic
.
Arithmetic1DComposite
,
"arithmetic_1dc"
:
arithmetic
.
Arithmetic1DComposite
,
# TODO Perhaps make these groups of tasks
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# e.g. anli, arithmetic, openai_translations, harness_translations
# e.g. wmt14-fr-en
# e.g. wmt14-fr-en
**
translation
.
create_tasks_from_benchmarks
(
translation
.
selected_benchmarks
)
**
translation
.
create_tasks_from_benchmarks
(
gpt3_translation_benchmarks
),
# chef's selection, mostly wmt20
**
translation
.
create_tasks_from_benchmarks
(
selected_translation_benchmarks
),
}
}
...
...
lm_eval/tasks/translation.py
View file @
1e5194a2
...
@@ -2,6 +2,7 @@ import abc
...
@@ -2,6 +2,7 @@ import abc
import
json
import
json
import
random
import
random
import
os
import
os
from
collections
import
Iterable
from
pprint
import
pprint
from
pprint
import
pprint
import
pycountry
import
pycountry
...
@@ -20,36 +21,9 @@ See sacrebleu.DATASETS for all available datasets. There are a lot!
...
@@ -20,36 +21,9 @@ See sacrebleu.DATASETS for all available datasets. There are a lot!
sacrebleu_datasets
=
sacrebleu
.
DATASETS
sacrebleu_datasets
=
sacrebleu
.
DATASETS
########################################
# Benchmarks one might want to run
########################################
# 6 total
gpt3_benchmarks
=
{
"wmt14"
:
[
'en-fr'
,
'fr-en'
],
# French
"wmt16"
:
[
'en-ro'
,
'ro-en'
,
'de-en'
,
'en-de'
],
# German, Romanian
}
# 14 total
selected_benchmarks
=
{
**
gpt3_benchmarks
,
"wmt20"
:
[
'fr-de'
,
'de-fr'
,
'en-ru'
,
'ru-en'
,
'en-iu'
,
'iu-en'
],
# French, German, Russian, Inuit
"iwslt17"
:
[
'en-ar'
,
'ar-en'
]
# Arabic
}
# 319 total
all_benchmarks
=
{
ts
:
sacrebleu
.
get_langpairs_for_testset
(
ts
)
for
ts
in
sacrebleu
.
get_available_testsets
()
}
available_tests
=
{
"gpt3_tests"
:
gpt3_benchmarks
,
"selected_tests"
:
selected_benchmarks
,
"all_tests"
:
all_benchmarks
}
def
create_tasks_from_benchmarks
(
benchmark_dict
):
def
create_tasks_from_benchmarks
(
benchmark_dict
):
"""Creates a dictionary of tasks from a dict
"""Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] }
:param benchmark_dict: { dataset: [lang_pair, ...]
,
}
:return: {task_name: task}
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
"""
...
@@ -115,9 +89,8 @@ class GeneralTranslationTask(Task):
...
@@ -115,9 +89,8 @@ class GeneralTranslationTask(Task):
return
doc
[
"src"
]
return
doc
[
"src"
]
def
doc_to_target
(
self
,
doc
):
def
doc_to_target
(
self
,
doc
):
# TODO Note that some exotic tests have multiple ref lines.
# This shows a single target, though there may be multiple targets in a lang test
# How does sacrebleu handle opening these files?
return
doc
[
"ref"
]
if
isinstance
(
doc
[
"ref"
],
str
)
else
doc
[
"ref"
][
0
]
return
doc
[
"ref"
]
def
construct_requests
(
self
,
doc
,
ctx
):
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns an iterable of
""" Uses RequestFactory to construct Requests and returns an iterable of
...
@@ -199,6 +172,14 @@ def print_available_tests():
...
@@ -199,6 +172,14 @@ def print_available_tests():
pprint
({
ts
:
sacrebleu
.
get_langpairs_for_testset
(
ts
)
for
ts
in
sacrebleu
.
get_available_testsets
()})
pprint
({
ts
:
sacrebleu
.
get_langpairs_for_testset
(
ts
)
for
ts
in
sacrebleu
.
get_available_testsets
()})
def
print_available_pairs
():
list_of_pairs
=
[
sacrebleu
.
get_langpairs_for_testset
(
ts
)
for
ts
in
sacrebleu
.
get_available_testsets
()]
pairs
=
set
([
item
for
sublist
in
list_of_pairs
for
item
in
sublist
])
pairs
=
sorted
([
"-"
.
join
(
map
(
code_to_language
,
pair
.
split
(
"-"
)))
for
pair
in
pairs
])
pprint
(
pairs
)
print
(
len
(
pairs
))
def
main
():
def
main
():
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests()
# print_available_tests()
...
@@ -213,6 +194,7 @@ def main():
...
@@ -213,6 +194,7 @@ def main():
# Test task dictionary
# Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class())
# print(task, task_class())
print_available_pairs
()
pass
pass
...
@@ -220,7 +202,6 @@ if __name__ == "__main__":
...
@@ -220,7 +202,6 @@ if __name__ == "__main__":
main
()
main
()
########################################
########################################
# Don't mind me...!
# Don't mind me...!
########################################
########################################
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment