gaoqiong / lm-evaluation-harness · Commits

Commit 6a336b15
Authored Dec 28, 2023 by lintangsutawika

use HFEvaluateAdaptor for hf metrics

Parent: 20c10dfe
Showing 3 changed files with 24 additions and 10 deletions:

lm_eval/api/metrics.py    +16  -2
lm_eval/api/registry.py    +4  -4
lm_eval/api/task.py        +4  -4
lm_eval/api/metrics.py

@@ -159,15 +159,29 @@ def acc_mutual_info_fn(items):
    return mean(items)


exact_match = evaluate.load("exact_match")


class HFEvaluateAdaptor:
    def __init__(self, *metric_args, **kwargs):
        metric_object = evaluate.load(*metric_args)
        self.hf_evaluate_fn = partial(metric_object, **kwargs)

    def __call__(self, items):
        refs = list(zip(*items))[0]
        preds = list(zip(*items))[1]
        return self.hf_evaluate_fn(references=refs, predictions=preds)


exact_match = evaluate.load("exact_match")


@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
)
def exact_match_fn(**kwargs):
def hf_evaluate_fn(**kwargs):
    return exact_match.compute(**kwargs)
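For reference, the item layout that HFEvaluateAdaptor.__call__ unpacks is a sequence of (reference, prediction) pairs, split apart by position. A minimal sketch, with made-up sample strings:

    # items as stored by the task layer: one (reference, prediction) pair per sample
    items = [("paris", "paris"), ("rome", "rome"), ("lyon", "marseille")]
    refs = list(zip(*items))[0]   # ("paris", "rome", "lyon")
    preds = list(zip(*items))[1]  # ("paris", "rome", "marseille")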
lm_eval/api/registry.py

import os

import evaluate

from lm_eval.api.model import LM
from lm_eval.api.metrics import HFEvaluateAdaptor

import logging


eval_logger = logging.getLogger("lm-eval")

@@ -115,7 +115,7 @@ def register_metric(
    return decorate


def get_metric(name, hf_evaluate_metric=False):
def get_metric(name, hf_evaluate_metric=False, **kwargs):
    if not hf_evaluate_metric:
        if name in METRIC_FUNCTION_REGISTRY:

@@ -126,8 +126,8 @@ def get_metric(name, hf_evaluate_metric=False):
            )

    try:
        metric_object = evaluate.load(name)
        return metric_object.compute
        from lm_eval.metrics import HFEvaluateAdaptor
        return HFEvaluateAdaptor(name, **kwargs)
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
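For reference, a self-contained sketch of the dispatch pattern the revised get_metric follows: return a registered lm-eval metric function when one exists, otherwise wrap the metric name in the HF Evaluate adaptor. The toy registry, the ToyAdaptor stand-in, and the metric names below are illustrative, not lm-eval code:

    TOY_REGISTRY = {"acc": lambda items: sum(items) / len(items)}

    class ToyAdaptor:  # stand-in for HFEvaluateAdaptor
        def __init__(self, name, **kwargs):
            self.name, self.kwargs = name, kwargs

    def get_metric_sketch(name, hf_evaluate_metric=False, **kwargs):
        if not hf_evaluate_metric and name in TOY_REGISTRY:
            return TOY_REGISTRY[name]         # lm-eval's own registered metric
        return ToyAdaptor(name, **kwargs)     # fall back to the HF Evaluate wrapper

    print(get_metric_sketch("acc"))                                 # lambda from the toy registry
    print(get_metric_sketch("bleu", hf_evaluate_metric=True).name)  # "bleu", wrapped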
lm_eval/api/task.py

@@ -17,7 +17,6 @@ import numpy as np

from typing import Union, List, Any, Tuple, Literal
from collections.abc import Callable
from functools import partial

from lm_eval import utils
from lm_eval.api import samplers

@@ -588,11 +587,11 @@ class ConfigurableTask(Task):
                    metric_name = metric_name.__name__
                else:
                    metric_fn = get_metric(
                        metric_name, hf_evaluate_metric
                        metric_name, hf_evaluate_metric, **kwargs
                    )
                    self._metric_fn_kwargs[metric_name] = kwargs
                    self._metric_fn_list[metric_name] = partial(metric_fn, **kwargs) if kwargs != {} else metric_fn
                    self._metric_fn_list[metric_name] = metric_fn

        self.download(self.config.dataset_kwargs)
        self._training_docs = None

@@ -1106,6 +1105,8 @@ class ConfigurableTask(Task):
                gold = type(result)(gold)

            for metric in self._metric_fn_list.keys():
                result_dict[metric] = (gold, result)
                continue

                if self.multiple_target:
                    # in the case where we have multiple targets,
                    # return true if any are true

@@ -1141,7 +1142,6 @@ class ConfigurableTask(Task):
                    result_score = self._metric_fn_list[metric](
                        references=[gold],
                        predictions=[result],
                        **self._metric_fn_kwargs[metric],
                    )
                except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                    result_score = self._metric_fn_list[metric]([gold, result])
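For reference, a minimal sketch of why the TypeError fallback above exists: an adaptor-style metric only accepts a single items argument, so the keyword-style call fails and the (gold, result) pair-list call is used instead. The toy metric below is illustrative, not lm-eval code:

    def toy_adaptor(items):  # items: list of (reference, prediction) pairs
        refs, preds = zip(*items)
        return sum(r == p for r, p in zip(refs, preds)) / len(items)

    gold, result = "yes", "yes"
    try:
        score = toy_adaptor(references=[gold], predictions=[result])  # raises TypeError
    except TypeError:
        score = toy_adaptor([(gold, result)])  # 1.0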