gaoqiong / lm-evaluation-harness / Commits / 7ad6bf45

Unverified commit 7ad6bf45, authored Feb 13, 2021 by Leo Gao, committed by GitHub on Feb 13, 2021.

Merge pull request #146 from EleutherAI/translation

Translation

Parents: 0601c909, ac47d481

Changes: 23 files changed in total; this page shows 20 changed files with 665 additions and 74 deletions (+665, -74).
.gitignore                    +2    -0
lm_eval/base.py               +13   -58
lm_eval/metrics.py            +135  -0
lm_eval/tasks/__init__.py     +14   -1
lm_eval/tasks/anli.py         +2    -1
lm_eval/tasks/arc.py          +3    -1
lm_eval/tasks/arithmetic.py   +2    -1
lm_eval/tasks/common.py       +3    -1
lm_eval/tasks/glue.py         +2    -1
lm_eval/tasks/lambada.py      +2    -1
lm_eval/tasks/piqa.py         +2    -1
lm_eval/tasks/pubmedqa.py     +2    -1
lm_eval/tasks/qa4mre.py       +2    -1
lm_eval/tasks/race.py         +2    -1
lm_eval/tasks/sat.py          +2    -1
lm_eval/tasks/sciq.py         +2    -1
lm_eval/tasks/superglue.py    +2    -1
lm_eval/tasks/translation.py  +468  -0
lm_eval/tasks/triviaqa.py     +2    -1
lm_eval/tasks/webqs.py        +3    -1
.gitignore  View file @ 7ad6bf45

 env
 *.pyc
 data/
+.idea
+lm_cache
\ No newline at end of file
lm_eval/base.py  View file @ 7ad6bf45

 import abc
 import random
 import numpy as np
-import sklearn
-import math
+from lm_eval.metrics import mean


 class LM(abc.ABC):
...
@@ -30,6 +30,7 @@ class LM(abc.ABC):
         """
         pass

+    # TODO: Add an optional max length
     @abc.abstractmethod
     def greedy_until(self, requests):
         """Generate greedily until a stopping sequence
...
@@ -61,6 +62,14 @@ class LM(abc.ABC):


 class Task(abc.ABC):
+    """A task represents an entire benchmark including its dataset, problems,
+    answers, and evaluation methods. See BoolQ for a simple example implementation
+
+    A `doc` can be any python object which represents one instance of evaluation.
+    This is usually a dictionary e.g.
+        {"question": ..., "answer": ...} or
+        {"question": ..., question, answer)
+    """
     def __init__(self):
         self.download()
         self._training_docs = None
...
@@ -148,9 +157,9 @@ class Task(abc.ABC):
     @abc.abstractmethod
     def aggregation(self):
         """
-        :returns: {str: [float] -> float}
+        :returns: {str: [metric_score] -> float}
             A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
+            functions that aggregate a list of metric scores
         """
         pass
...
@@ -213,60 +222,6 @@ class MultipleChoiceTask(Task):
         }

-def mean(arr):
-    return sum(arr) / len(arr)

-def median(arr):
-    return arr[len(arr) // 2]

-def matthews_corrcoef(items):
-    unzipped_list = list(zip(*items))
-    golds = unzipped_list[0]
-    preds = unzipped_list[1]
-    return sklearn.metrics.matthews_corrcoef(golds, preds)

-def f1_score(items):
-    unzipped_list = list(zip(*items))
-    golds = unzipped_list[0]
-    preds = unzipped_list[1]
-    fscore = sklearn.metrics.f1_score(golds, preds)
-    return np.max(fscore)

-def acc_all(items):
-    # Only count as correct if all answers are labeled correctly for each question
-    question_scoring_dict = {}
-    preds = list(zip(*items))[0]
-    docs = list(zip(*items))[1]
-
-    for doc, pred in zip(docs, preds):
-        question_id = doc["idx"]["question"]
-        if question_id not in question_scoring_dict:
-            question_scoring_dict[question_id] = []
-        gold_label = doc["label"] == 1
-        question_scoring_dict[question_id].append(gold_label == pred)
-    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
-    return acc

-def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
-    """Compute max metric between prediction and each ground truth."""
-    scores_for_ground_truths = []
-    for ground_truth in ground_truths:
-        score = metric_fn(prediction, ground_truth)
-        scores_for_ground_truths.append(score)
-    return max(scores_for_ground_truths)

-def perplexity(items):
-    return math.exp(-mean(items))

 req_ret_lens = {
     'loglikelihood': 2,
     'greedy_until': None,
...
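Aside on the aggregation contract touched in the hunk above: each submetric name maps to a function that collapses a list of per-document scores into one number, and those helpers now live in lm_eval/metrics.py rather than base.py. A minimal sketch of that shape, assuming lm_eval is importable; the variable names and log-likelihood values below are invented for illustration:

    from lm_eval.metrics import mean, perplexity

    # Invented per-document log-likelihoods, standing in for harness results.
    loglikelihoods = [-1.2, -0.7, -2.3]

    aggregations = {
        "avg_logprob": mean,   # [float] -> float
        "ppl": perplexity,     # math.exp(-mean(scores))
    }
    for name, aggregate in aggregations.items():
        print(name, aggregate(loglikelihoods))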
lm_eval/metrics.py  0 → 100644  View file @ 7ad6bf45

import math
from pprint import pprint

import numpy as np
import sacrebleu
import sklearn


def mean(arr):
    return sum(arr) / len(arr)


def median(arr):
    return arr[len(arr) // 2]


def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)
    return np.max(fscore)


def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []
        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def perplexity(items):
    return math.exp(-mean(items))


def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    n-grams in the candidate translation to n-grams in the reference text, where
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better  # TODO I think
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str])
    #   e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions that I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not isinstance(refs, list):
        refs = list(refs)
    if not isinstance(refs[0], list):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list much match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not isinstance(preds, list):
        preds = list(preds)
    if isinstance(preds[0], list):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds
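For orientation on the new corpus metrics: each takes a list of (reference, prediction) items as collected by a task's process_results, and _sacreformat reshapes them into sacrebleu's expected layout of List[str] predictions against List[List[str]] reference streams. A small usage sketch, assuming lm_eval and sacrebleu are installed; the sentence pairs are invented:

    from lm_eval.metrics import bleu, chrf, ter

    # Each item is (reference, prediction); a prediction may also arrive as a
    # one-element list, e.g. the result of a single greedy_until request.
    items = [
        ("the cat sat on the mat", ["the cat sat on the mat"]),
        ("a quick brown fox", ["a fast brown fox"]),
    ]

    print(bleu(items))   # corpus BLEU, higher is better
    print(chrf(items))   # corpus chrF, higher is better
    print(ter(items))    # corpus TER, lower is better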
lm_eval/tasks/__init__.py  View file @ 7ad6bf45

+from pprint import pprint
+
 from . import superglue
 from . import glue
 from . import arc
...
@@ -21,6 +23,7 @@ from . import pubmedqa
 from . import sciq
 from . import webqs
 from . import qa4mre
+from . import translation
 from . import headqa
 from . import mathqa
...
@@ -88,6 +91,11 @@ TASK_REGISTRY = {
     "arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
     "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
+    # TODO Perhaps make these groups of tasks
+    #   e.g. anli, arithmetic, openai_translations, harness_translations
+
+    # e.g. wmt14-fr-en
+    **translation.create_tasks_from_benchmarks(translation.selected_benchmarks)
 }
...
@@ -95,7 +103,12 @@ ALL_TASKS = sorted(list(TASK_REGISTRY))

 def get_task(task_name):
-    return TASK_REGISTRY[task_name]
+    try:
+        return TASK_REGISTRY[task_name]
+    except KeyError as e:
+        print("Available tasks:")
+        pprint(TASK_REGISTRY)
+        raise KeyError(f"Missing task {task_name}")


 def get_task_dict(task_name_list):
...
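With the registry change above, one entry is spliced into TASK_REGISTRY per (dataset, language pair), under names like wmt14-fr-en, and get_task now prints the registry before re-raising on an unknown name. A small usage sketch, assuming the package is installed; the failing task name is deliberately made up:

    from lm_eval import tasks

    # Translation entries added by create_tasks_from_benchmarks(selected_benchmarks)
    print([t for t in tasks.ALL_TASKS if t.startswith(("wmt", "iwslt"))])

    wmt14_fr_en = tasks.get_task("wmt14-fr-en")   # a generated TranslationTask class

    try:
        tasks.get_task("wmt99-xx-yy")             # unknown: prints available tasks
    except KeyError as err:
        print(err)                                # "Missing task wmt99-xx-yy"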
lm_eval/tasks/anli.py  View file @ 7ad6bf45

 import numpy as np
-from lm_eval.base import rf, mean
+from lm_eval.base import rf
+from ..metrics import mean
 from . common import HFTask


 class ANLIBase(HFTask):
...
lm_eval/tasks/arc.py  View file @ 7ad6bf45

+import numpy as np
 from lm_eval.base import MultipleChoiceTask
-from .common import HFTask
+from ..metrics import mean
+from . common import HFTask


 class ARCEasy(HFTask, MultipleChoiceTask):
...
lm_eval/tasks/arithmetic.py  View file @ 7ad6bf45

@@ -2,7 +2,8 @@ import abc
 import json
 import os
 from collections import namedtuple
-from lm_eval.base import Task, mean, rf
+from lm_eval.base import Task, rf
+from lm_eval.metrics import mean
 from best_download import download_file

 ArithmeticDoc = namedtuple('ArithmeticDoc', ['context', 'completion'])
...
lm_eval/tasks/common.py  View file @ 7ad6bf45

 import datasets
 import numpy as np
+import lm_eval.metrics
 from ..base import Task
...
@@ -44,7 +46,7 @@ class HFTask(Task):

 def simple_accuracy_metric(preds, golds):
-    acc = float((np.array(preds) == np.array(golds)).mean())
+    acc = float(lm_eval.metrics.mean())
     return {
         "major": acc,
         "minor": {"acc": acc},
...
lm_eval/tasks/glue.py  View file @ 7ad6bf45

 import numpy as np
-from lm_eval.base import rf, mean, f1_score, matthews_corrcoef
+from lm_eval.base import rf
+from ..metrics import mean, matthews_corrcoef, f1_score
 from scipy.stats import pearsonr, spearmanr
 from tqdm import auto as tqdm_lib
 from . common import HFTask, yesno
...
lm_eval/tasks/lambada.py  View file @ 7ad6bf45

-from lm_eval.base import Task, rf, mean, perplexity
+from lm_eval.base import Task, rf
+from lm_eval.metrics import mean, perplexity
 from lm_eval.utils import sh
 import json
 import math
...
lm_eval/tasks/piqa.py  View file @ 7ad6bf45

 import numpy as np
-from lm_eval.base import rf, mean
+from lm_eval.base import rf
+from ..metrics import mean
 from . common import HFTask
...
lm_eval/tasks/pubmedqa.py  View file @ 7ad6bf45

@@ -2,7 +2,8 @@ import numpy as np
 import json
 import random
 from .common import HFTask
-from lm_eval.base import rf, mean
+from lm_eval.base import rf
+from ..metrics import mean


 class Pubmed_QA(HFTask):
...
lm_eval/tasks/qa4mre.py  View file @ 7ad6bf45

 import os
 import numpy as np
 from best_download import download_file
-from lm_eval.base import MultipleChoiceTask, rf, mean
+from lm_eval.base import MultipleChoiceTask, rf
+from lm_eval.metrics import mean
 import xml.etree.ElementTree as ET
 import random
...
lm_eval/tasks/race.py  View file @ 7ad6bf45

 import collections
 import datasets
 import numpy as np
-from lm_eval.base import rf, mean
+from lm_eval.base import rf
+from ..metrics import mean
 from . common import HFTask
 import os
...
lm_eval/tasks/sat.py  View file @ 7ad6bf45

 import json
 import random
 import os
-from lm_eval.base import MultipleChoiceTask, rf, mean
+from lm_eval.base import MultipleChoiceTask, rf
+from ..metrics import mean
 from tqdm import auto as tqdm_lib
 from . common import simple_accuracy_metric
 import numpy as np
...
lm_eval/tasks/sciq.py  View file @ 7ad6bf45

 import os
 import json
 from ..utils import sh
-from lm_eval.base import MultipleChoiceTask, rf, mean
+from lm_eval.base import MultipleChoiceTask, rf
+from ..metrics import mean
 import zipfile
 from best_download import download_file
...
lm_eval/tasks/superglue.py  View file @ 7ad6bf45

@@ -5,7 +5,8 @@ To-do:
 """
 import numpy as np
 from . common import HFTask, yesno
-from lm_eval.base import rf, mean, acc_all, metric_max_over_ground_truths
+from lm_eval.base import rf
+from ..metrics import mean, acc_all, metric_max_over_ground_truths
 import sklearn
 import transformers.data.metrics.squad_metrics as squad_metrics
 from ..utils import general_detokenize
...
lm_eval/tasks/translation.py  0 → 100644  View file @ 7ad6bf45

import abc
import json
import random
import os
from pprint import pprint

import pycountry
from sacrebleu import sacrebleu
import logging

from lm_eval import metrics
from lm_eval.base import Task, rf

"""
This file implements translation tasks using datasets from WMT conferences, provided by sacrebleu.
Traditionally they are evaluated with BLEU scores. TER and CHRF are other options.

See sacrebleu.DATASETS for all available datasets. There are a lot!
"""

sacrebleu_datasets = sacrebleu.DATASETS


########################################
# Benchmarks one might want to run
########################################

# 6 total
gpt3_benchmarks = {
    "wmt14": ['en-fr', 'fr-en'],  # French
    "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'],  # German, Romanian
}

# 14 total
selected_benchmarks = {
    **gpt3_benchmarks,
    "wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'],  # French, German, Russian, Inuit
    "iwslt17": ['en-ar', 'ar-en']  # Arabic
}

# 319 total
all_benchmarks = {
    ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()
}

available_tests = {
    "gpt3_tests": gpt3_benchmarks,
    "selected_tests": selected_benchmarks,
    "all_tests": all_benchmarks
}


def create_tasks_from_benchmarks(benchmark_dict):
    """Creates a dictionary of tasks from a dict
    :param benchmark_dict: { dataset: [lang_pair, ...] }
    :return: {task_name: task}
        e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
    """
    return {
        f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair)
        for dataset, language_pairs in benchmark_dict.items()
        for language_pair in language_pairs
    }


########################################
# Tasks
########################################


def create_translation_task(dataset, language_pair):
    class TranslationTask(GeneralTranslationTask):
        def __init__(self):
            super().__init__(dataset, language_pair)
    return TranslationTask


class GeneralTranslationTask(Task):

    # e.g. ("wmt14", "fr-en")
    def __init__(self, sacrebleu_dataset, sacrebleu_language_pair=None):
        self.sacrebleu_dataset = sacrebleu_dataset
        self.sacrebleu_language_pair = sacrebleu_language_pair
        self.src_file = self.ref_file = self.src_data = self.ref_data = None

        super().__init__()

    def download(self):
        # This caches in the users home dir automatically
        self.src_file, self.ref_file = \
            sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair)
        self.src_data, self.ref_data = [
            [line.rstrip() for line in sacrebleu.smart_open(file)]
            for file in (self.src_file, self.ref_file)
        ]

    def has_training_docs(self):
        """Whether the task has a training set"""
        # TODO In the future we could be more discerning. Some more recent tests have train and dev sets
        return False

    def has_validation_docs(self):
        """Whether the task has a validation set"""
        return False

    def has_test_docs(self):
        """Whether the task has a test set"""
        return True

    def test_docs(self):
        """
        :return: Iterable[obj]
            A iterable of any object, that doc_to_text can handle
        """
        return [{
            "src": src,
            "ref": ref
        } for src, ref in zip(self.src_data, self.ref_data)]

    def doc_to_text(self, doc):
        return doc["src"]

    def doc_to_target(self, doc):
        # TODO Note that some exotic tests have multiple ref lines.
        #  How does sacrebleu handle opening these files?
        return doc["ref"]

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        return rf.greedy_until(ctx, ["\n"])

    def process_results(self, doc, results):
        # These metrics are corpus-level not sentence level, so we'll hide the
        # results in this dict and compute the corpus score in the aggregate method
        ref_pred = (doc["ref"], results)
        return {
            "bleu": ref_pred,
            "chrf": ref_pred,
            "ter": ref_pred,
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "bleu": metrics.bleu,
            "chrf": metrics.chrf,
            "ter": metrics.ter,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "bleu": True,
            "chrf": True,
            "ter": False,
        }

    def fewshot_description(self):
        language_codes = self.sacrebleu_language_pair.split("-")
        src_lang = code_to_language(language_codes[0])
        tar_lang = code_to_language(language_codes[1])
        return f"Translate these {src_lang} phrases to {tar_lang}."
        # TODO This should be something like
        #  French: {src_line}
        #  English: {ref_line}

    def fewshot_context(self, doc, num_fewshot, provide_description):
        return ""

    def __str__(self):
        language_codes = self.sacrebleu_language_pair.split("-")
        src_lang = code_to_language(language_codes[0])
        tar_lang = code_to_language(language_codes[1])
        return f"{self.sacrebleu_dataset.upper()} {src_lang} to {tar_lang} Task"


########################################
# Util
########################################


def code_to_language(code):
    # key is alpha_2 or alpha_3 depending on the code length
    language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
    return language_tuple.name


def print_available_tests():
    pprint({ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()})


def main():
    # print(sacrebleu.download_test_set("wmt14", "en-fr"))
    # print_available_tests()
    # sacrebleu.print_test_set("wmt14", "fr-en", "src")

    # # Print number of benchmarks
    # print(sum([
    #     len(sacrebleu.get_langpairs_for_testset(ts))
    #     for ts in sacrebleu.get_available_testsets()
    # ]))

    # Test task dictionary
    # for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
    #     print(task, task_class())
    pass


if __name__ == "__main__":
    main()


########################################
# Don't mind me...!
########################################

# Available tests as of 2020/02/11
"""
{'iwslt17': ['en-fr',
'fr-en',
'en-de',
'de-en',
'en-zh',
'zh-en',
'en-ar',
'ar-en',
'en-ja',
'ja-en',
'en-ko',
'ko-en'],
'iwslt17/dev2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2011': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2012': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2013': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2014': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2015': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2016': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'mtnt1.1/test': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/train': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/valid': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt2019': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'multi30k/2016': ['en-fr', 'en-de', 'en-cs'],
'multi30k/2017': ['en-fr', 'en-de'],
'multi30k/2018': ['en-fr', 'en-de'],
'wmt08': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu'],
'wmt08/europarl': ['de-en', 'en-de', 'es-en', 'en-es', 'fr-en', 'en-fr'],
'wmt08/nc': ['cs-en', 'en-cs'],
'wmt09': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu',
'it-en',
'en-it'],
'wmt10': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt11': ['cs-en',
'en-cs',
'de-en',
'en-de',
'fr-en',
'en-fr',
'es-en',
'en-es'],
'wmt12': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt13': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'ru-en',
'en-ru'],
'wmt14': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt14/full': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt15': ['en-fr',
'fr-en',
'cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ru',
'fi-en',
'ru-en'],
'wmt16': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ro',
'en-ru',
'en-tr',
'fi-en',
'ro-en',
'ru-en',
'tr-en'],
'wmt16/B': ['en-fi'],
'wmt16/dev': ['en-ro', 'en-tr', 'ro-en', 'tr-en'],
'wmt16/tworefs': ['en-fi'],
'wmt17': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-lv',
'en-ru',
'en-tr',
'en-zh',
'fi-en',
'lv-en',
'ru-en',
'tr-en',
'zh-en'],
'wmt17/B': ['en-fi'],
'wmt17/dev': ['en-lv', 'en-zh', 'lv-en', 'zh-en'],
'wmt17/improved': ['en-zh', 'zh-en'],
'wmt17/ms': ['zh-en'],
'wmt17/tworefs': ['en-fi'],
'wmt18': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt18/dev': ['et-en', 'en-et'],
'wmt18/test-ts': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt19': ['cs-de',
'de-cs',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-fi',
'en-gu',
'en-kk',
'en-lt',
'en-ru',
'en-zh',
'fi-en',
'fr-de',
'gu-en',
'kk-en',
'lt-en',
'ru-en',
'zh-en'],
'wmt19/dev': ['lt-en', 'en-lt', 'gu-en', 'en-gu', 'kk-en', 'en-kk'],
'wmt19/google/ar': ['en-de'],
'wmt19/google/arp': ['en-de'],
'wmt19/google/hqall': ['en-de'],
'wmt19/google/hqp': ['en-de'],
'wmt19/google/hqr': ['en-de'],
'wmt19/google/wmtp': ['en-de'],
'wmt20': ['cs-en',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-iu',
'en-ja',
'en-km',
'en-pl',
'en-ps',
'en-ru',
'en-ta',
'en-zh',
'fr-de',
'iu-en',
'ja-en',
'km-en',
'pl-en',
'ps-en',
'ru-en',
'ta-en',
'zh-en'],
'wmt20/dev': ['iu-en',
'en-iu',
'ja-en',
'en-ja',
'pl-en',
'en-pl',
'ta-en',
'en-ta'],
'wmt20/robust/set1': ['en-ja', 'en-de'],
'wmt20/robust/set2': ['en-ja', 'ja-en'],
'wmt20/robust/set3': ['de-en'],
'wmt20/tworefs': ['de-en', 'en-de', 'en-zh', 'ru-en', 'zh-en']}
"""
\ No newline at end of file
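To see how the new module hangs together without downloading any WMT data, here is a hedged sketch, assuming lm_eval and its pycountry and sacrebleu dependencies are installed; the variable names are invented and the commented-out instantiation is the step that would trigger the download:

    from lm_eval.tasks import translation

    # One generated Task subclass per (dataset, language pair); nothing is
    # downloaded until a class is instantiated.
    task_classes = translation.create_tasks_from_benchmarks(translation.selected_benchmarks)
    print(sorted(task_classes))                 # ['iwslt17-ar-en', ..., 'wmt14-en-fr', 'wmt14-fr-en', ...]

    print(translation.code_to_language("fr"))   # 'French'  (alpha_2 lookup via pycountry)
    print(translation.code_to_language("eng"))  # 'English' (alpha_3 lookup)

    # Instantiating would call download() via sacrebleu and expose test_docs()
    # as [{"src": ..., "ref": ...}, ...]:
    # fr_en_task = task_classes["wmt14-fr-en"]()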
lm_eval/tasks/triviaqa.py  View file @ 7ad6bf45

 import os
 import json
 import random
-from lm_eval.base import Task, mean, rf
+from lm_eval.base import Task, rf
+from ..metrics import mean
 from ..utils import sh


 class TriviaQA(Task):
...
lm_eval/tasks/webqs.py  View file @ 7ad6bf45

 from . common import HFTask
-from lm_eval.base import mean, rf
+from lm_eval.base import rf
+from ..metrics import mean


 class WebQs(HFTask):
     DATASET_PATH = "web_questions"
...