gaoqiong / lm-evaluation-harness · Commit 2a9da9fb

add metric + agg registries

Authored Apr 24, 2023 by haileyschoelkopf; committed by Hailey Schoelkopf on Apr 24, 2023
Parent: 460584ca
Showing 4 changed files with 91 additions and 45 deletions (+91, -45):
  lm_eval/api/__init__.py     +0   -15
  lm_eval/api/metrics.py      +68  -0
  lm_eval/api/task.py         +22  -19
  lm_eval/models/__init__.py  +1   -11
lm_eval/api/__init__.py  (view file @ 2a9da9fb)

-from . import metrics
-
-METRIC_REGISTRY = {
-    "matthews_corrcoef": metrics.matthews_corrcoef,
-    "f1_score": metrics.f1_score,
-    "perplexity": metrics.perplexity,
-    "bleu": metrics.bleu,
-    "chrf": metrics.chrf,
-    "ter": metrics.ter,
-}
-
-AGGREGATION_REGISTRY = {
-    "mean": metrics.mean,
-    "median": metrics.median
-}
\ No newline at end of file
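
The hand-maintained dictionaries deleted here are superseded by the decorator-populated registries added to lm_eval/api/metrics.py below. As a rough sketch of what this means for callers (the metric name is only an illustrative value, not part of the diff):

# Before this commit, callers indexed the module-level dict directly:
#   from lm_eval.api import METRIC_REGISTRY
#   bleu_fn = METRIC_REGISTRY["bleu"]
# After this commit, callers go through the accessor added in metrics.py,
# which can also fall back to the HF Evaluate library for unknown names.
from lm_eval.api.metrics import get_metric

bleu_fn = get_metric("bleu")  # resolves to lm_eval.api.metrics.bleu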
lm_eval/api/metrics.py  (view file @ 2a9da9fb)

@@ -6,7 +6,67 @@ import sacrebleu
 import sklearn.metrics
 import random
+import evaluate
+
+
+AGGREGATION_REGISTRY = {}
+METRIC_REGISTRY = {}
+
+
+def register_metric(name):
+    # TODO: do we want to enforce a certain interface to registered metrics?
+    def decorate(fn):
+        assert (
+            name not in METRIC_REGISTRY
+        ), f"metric named '{name}' conflicts with existing registered metric!"
+
+        METRIC_REGISTRY[name] = fn
+        return fn
+
+    return decorate
+
+
+def get_metric(name):
+    try:
+        return METRIC_REGISTRY[name]
+    except KeyError:
+        # TODO: change this print to logging?
+        print(
+            f"Could not find registered metric '{name}' in lm-eval, \
+searching in HF Evaluate library..."
+        )
+        try:
+            metric_object = evaluate.load(name)
+            return metric_object.compute
+        except:
+            raise Warning(
+                "{} not found in the evaluate library!".format(name),
+                "Please check https://huggingface.co/evaluate-metric",
+            )
+
+
+def register_aggregation(name):
+    def decorate(fn):
+        assert (
+            name not in AGGREGATION_REGISTRY
+        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
+
+        AGGREGATION_REGISTRY[name] = fn
+        return fn
+
+    return decorate
+
+
+def get_aggregation(name):
+    try:
+        return AGGREGATION_REGISTRY[name]
+    except KeyError:
+        raise Warning(
+            "{} not a registered aggregation metric!".format(name),
+        )
+
+
+@register_aggregation("mean")
 def mean(arr):
     return sum(arr) / len(arr)

@@ -25,10 +85,12 @@ def mean_stderr(arr):
     return sample_stddev(arr) / math.sqrt(len(arr))


+@register_aggregation("median")
 def median(arr):
     return arr[len(arr) // 2]


+@register_metric("matthews_corrcoef")
 def matthews_corrcoef(items):
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]

@@ -36,6 +98,7 @@ def matthews_corrcoef(items):
     return sklearn.metrics.matthews_corrcoef(golds, preds)


+@register_metric("f1_score")
 def f1_score(items):
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]

@@ -91,6 +154,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)


+@register_metric("perplexity")
 def perplexity(items):
     return math.exp(-mean(items))

@@ -100,6 +164,7 @@ def weighted_mean(items):
     return sum(a) / sum(b)


+@register_metric("weighted_perplexity")
 def weighted_perplexity(items):
     return math.exp(-weighted_mean(items))

@@ -108,6 +173,7 @@ def bits_per_byte(items):
     return -weighted_mean(items) / math.log(2)


+@register_metric("bleu")
 def bleu(items):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching

@@ -125,6 +191,7 @@ def bleu(items):
     return sacrebleu.corpus_bleu(preds, refs).score


+@register_metric("chrf")
 def chrf(items):
     """chrF++ is a tool for automatic evaluation of machine translation output
     based on character n-gram precision and recall enhanced with word n-grams.

@@ -139,6 +206,7 @@ def chrf(items):
     return sacrebleu.corpus_chrf(preds, refs).score


+@register_metric("ter")
 def ter(items):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one
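
For readers skimming the diff, here is a small usage sketch of the registry API this file now exposes. Only register_metric, register_aggregation, get_metric, and get_aggregation come from this commit; the "exact_match_strict" metric and "max" aggregation below are hypothetical examples, not part of the change.

from lm_eval.api.metrics import (
    register_metric,
    register_aggregation,
    get_metric,
    get_aggregation,
)


# Hypothetical metric registered under a new, unused name; re-registering an
# existing name would trip the assert inside register_metric.
@register_metric("exact_match_strict")
def exact_match_strict(items):
    # items is assumed to be an iterable of (gold, pred) pairs, the convention
    # used by the built-in metrics in this file.
    golds, preds = zip(*items)
    return sum(g == p for g, p in zip(golds, preds)) / len(golds)


# Hypothetical aggregation registered alongside the built-in "mean" and "median".
@register_aggregation("max")
def max_agg(arr):
    return max(arr)


# Registered names resolve locally; unknown metric names make get_metric fall
# back to evaluate.load(...) and return that metric object's .compute method.
metric_fn = get_metric("exact_match_strict")
agg_fn = get_aggregation("max")
print(metric_fn([("a", "a"), ("a", "b")]), agg_fn([0.1, 0.9]))  # 0.5 0.9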
lm_eval/api/task.py  (view file @ 2a9da9fb)

@@ -11,9 +11,8 @@ import numpy as np
 from typing import List, Union

-from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
 from lm_eval.api.instance import Instance
-from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
+from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
 from lm_eval import utils
 from lm_eval.filters import build_filter_ensemble

@@ -32,8 +31,8 @@ class TaskConfig(dict):
     fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
     template_aliases: str = ""
-    doc_to_text: str = None
-    doc_to_target: str = None
+    doc_to_text: str = ""
+    doc_to_target: str = ""
     # aggregation: dict = None  # TODO: remove, I think these 2 are obsolete w/ current metric_list impl.
     # higher_is_better: dict = None

@@ -111,7 +110,7 @@ class Task(abc.ABC):
         self._fewshot_docs = None
         self._instances = None

-        self._config = TaskConfig(**config) if config else {}
+        self._config = TaskConfig(**config) if config else TaskConfig()

         if not hasattr(self, "_filters"):
             self._filters = []

@@ -392,20 +391,23 @@ class ConfigurableTask(Task):
         self._higher_is_better = {}
         for (metric_name, aggregation, higher_is_better) in self._config.metric_list:
-            self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
+            self._aggregation_list[metric_name] = get_aggregation(aggregation)
             self._higher_is_better[metric_name] = higher_is_better
-            if metric_name in METRIC_REGISTRY.keys():
-                self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
-            else:
-                try:
-                    metric_object = evaluate.load(metric_name)
-                    self._metric_list[metric_name] = metric_object
-                except Exception as ex:
-                    raise Warning(
-                        "{} not found in the evaluate library!".format(metric_name),
-                        "Please check https://huggingface.co/evaluate-metric",
-                    )
+            self._metric_list[metric_name] = get_metric(metric_name)
+            # if metric_name in METRIC_REGISTRY.keys():
+            #     self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
+            # else:
+            #     try:
+            #         metric_object = evaluate.load(metric_name)
+            #         self._metric_list[metric_name] = metric_object
+            #     except Exception as ex:
+            #         raise Warning(
+            #             "{} not found in the evaluate library!".format(metric_name),
+            #             "Please check https://huggingface.co/evaluate-metric",
+            #         )

         self.download(data_dir, cache_dir, download_mode)
         self._training_docs = None

@@ -478,7 +480,7 @@ class ConfigurableTask(Task):
         result_dict = {}
         for key, result in zip(self._metric_list.keys(), results):
-            _dict = self._metric_list[key].compute(
+            _dict = self._metric_list[key](
                 references=[gold],
                 predictions=[result],
             )

@@ -493,7 +495,7 @@ class ConfigurableTask(Task):
     def higher_is_better(self):
-        return self._higher_is_better_list
+        return self._higher_is_better


 class MultipleChoiceTask(Task):

@@ -659,6 +661,7 @@ def get_task_name_from_object(task_object):
         if class_ is task_object:
             return name

+    # TODO: scrap this
     # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
     return (
         task_object.EVAL_HARNESS_NAME
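
To make the new wiring in ConfigurableTask concrete, the registry loop above can be read as the following standalone sketch; the metric_list tuples are illustrative values, while get_metric and get_aggregation are the accessors added in this commit.

from lm_eval.api.metrics import get_metric, get_aggregation

# Hypothetical task configuration: (metric_name, aggregation, higher_is_better)
# tuples, in the shape ConfigurableTask iterates over.
metric_list = [
    ("bleu", "mean", True),
    ("perplexity", "mean", False),
]

_metric_list, _aggregation_list, _higher_is_better = {}, {}, {}
for (metric_name, aggregation, higher_is_better) in metric_list:
    # Registered names resolve to functions in lm_eval.api.metrics; an unknown
    # metric name would make get_metric fall back to the HF Evaluate library.
    _aggregation_list[metric_name] = get_aggregation(aggregation)
    _metric_list[metric_name] = get_metric(metric_name)
    _higher_is_better[metric_name] = higher_is_better

Because get_metric returns metric_object.compute for HF Evaluate fallbacks, process_results can call the stored callable directly, which is what the @@ -478,7 +480,7 @@ hunk changes.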
lm_eval/models/__init__.py  (view file @ 2a9da9fb)

@@ -5,14 +5,4 @@ from . import gpt3
 from . import textsynth
 from . import dummy

-# MODEL_REGISTRY = {}
-# MODEL_REGISTRY = {
-#     "hf-causal": gpt2.HFLM,
-#     "openai": gpt3.GPT3LM,
-#     "textsynth": textsynth.TextSynthLM,
-#     "dummy": dummy.DummyLM,
-# }
-
-
-# def get_model(model_name):
-#     return MODEL_REGISTRY[model_name]
+# TODO: implement __all__