gaoqiong / lm-evaluation-harness · Commits

Commit 9d6bc929
authored Dec 28, 2023 by lintangsutawika

aggregation to compute_metric

parent 4d49dd03

Showing 1 changed file with 37 additions and 57 deletions:

lm_eval/api/task.py (+37, -57)
@@ -29,16 +29,11 @@ from lm_eval.api.metrics import (
     mean,
     weighted_perplexity,
     bits_per_byte,
-    metric_max_over_ground_truths,
 )
 from lm_eval.api.registry import (
     get_metric,
-    get_aggregation,
-    get_metric_aggregation,
     is_higher_better,
     DEFAULT_METRIC_REGISTRY,
-    OUTPUT_TYPE_REGISTRY,
-    AGGREGATION_REGISTRY,
 )
 
 ALL_OUTPUT_TYPES = [
@@ -415,7 +410,7 @@ class Task(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def aggregation(self):
+    def compute_metric(self):
         """
         :returns: {str: [metric_score] -> float}
             A dictionary where keys are the names of submetrics and values are
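With this rename, any Task subclass that previously overrode aggregation() must now override compute_metric(); the contract in the docstring is unchanged. A minimal sketch of the new override point (the subclass and metric choice are hypothetical; mean is the real helper imported above):

from lm_eval.api.metrics import mean
from lm_eval.api.task import Task

class IllustrativeTask(Task):  # hypothetical subclass, for illustration only
    def compute_metric(self):
        # Same contract as the old aggregation(): submetric name -> a function
        # that reduces a list of per-doc scores to a single float.
        return {"acc": mean}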
@@ -569,13 +564,6 @@ class ConfigurableTask(Task):
                 metric = get_metric(metric_name)
                 self._metric_fn_list[metric_name] = metric
                 self._metric_fn_kwargs[metric_name] = {}
-                self._aggregation_list[metric_name] = metric.aggregation
-                # try:
-                #     self._aggregation_list[metric_name] = metric.aggregation
-                # except:
-                #     self._aggregation_list[metric_name] = get_metric_aggregation(
-                #         metric_name
-                #     )
                 self._higher_is_better[metric_name] = is_higher_better(metric_name)
         else:
             for metric_config in self.config.metric_list:
@@ -606,36 +594,6 @@ class ConfigurableTask(Task):
                 )
                 self._metric_fn_kwargs[metric_name] = kwargs
 
-                # if "aggregation" in metric_config:
-                #     agg_name = metric_config["aggregation"]
-                #     if type(agg_name) == str:
-                #         self._aggregation_list[metric_name] = get_aggregation(agg_name)
-                #     elif callable(agg_name):
-                #         self._aggregation_list[metric_name] = metric_config[
-                #             "aggregation"
-                #         ]
-                #     else:
-                #         INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
-                #         metric_agg = get_metric_aggregation(metric_name)
-                #         eval_logger.warning(
-                #             f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
-                #             f"using default "
-                #             f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
-                #         )
-                #         self._aggregation_list[metric_name] = metric_agg
-
-                # if "higher_is_better" in metric_config:
-                #     self._higher_is_better[metric_name] = metric_config[
-                #         "higher_is_better"
-                #     ]
-                # else:
-                #     eval_logger.warning(
-                #         f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
-                #         f"using default "
-                #         f"higher_is_better={is_higher_better(metric_name)}"
-                #     )
-                #     self._higher_is_better[metric_name] = is_higher_better(metric_name)
 
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
         self._fewshot_docs = None
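After these two hunks the constructor keeps no aggregation bookkeeping at all: per metric it records only the metric function, its kwargs, and the higher_is_better flag. A hedged, standalone sketch of what gets recorded, using the registry helpers that remain imported (the metric name is illustrative):

from lm_eval.api.registry import get_metric, is_higher_better

metric_name = "acc"  # illustrative; any registered default metric
_metric_fn_list = {metric_name: get_metric(metric_name)}
_metric_fn_kwargs = {metric_name: {}}
_higher_is_better = {metric_name: is_higher_better(metric_name)}
print(_higher_is_better)  # expected: {'acc': True}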
@@ -1023,19 +981,43 @@ class ConfigurableTask(Task):
         )
 
     def process_results(self, doc, results):
+        # Process results returns 1 of X things per doc/results
+        # 1. A score
+        # 2. Components to be processed later to obtained a score. such as gold and prediction
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)
 
         result_dict = {}
+        use_metric = list(self._metric_fn_list.keys())
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
-            return ll, is_greedy
+            return {
+                **({"perplexity": ll} if "perplexity" in use_metric else {}),
+                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
+            }
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             (loglikelihood,) = results
             _words = self.count_words(self.doc_to_target(doc))
             _bytes = self.count_bytes(self.doc_to_target(doc))
-            return loglikelihood, _words, _bytes
+            return {
+                **(
+                    {"word_perplexity": (loglikelihood, _words)}
+                    if "word_perplexity" in use_metric
+                    else {}
+                ),
+                **(
+                    {"byte_perplexity": (loglikelihood, _bytes)}
+                    if "byte_perplexity" in use_metric
+                    else {}
+                ),
+                **(
+                    {"bits_per_byte": (loglikelihood, _bytes)}
+                    if "bits_per_byte" in use_metric
+                    else {}
+                ),
+            }
         elif self.OUTPUT_TYPE == "multiple_choice":
             lls, is_greedy = zip(*results)
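The loglikelihood branches now return metric-keyed dicts instead of bare tuples, filtered down to the metrics the task actually configured. An illustrative, runnable check of the new shape (the scores are made up):

ll, is_greedy = -2.31, True  # hypothetical per-doc model outputs
use_metric = ["perplexity", "acc"]
result = {
    **({"perplexity": ll} if "perplexity" in use_metric else {}),
    **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
assert result == {"perplexity": -2.31, "acc": 1}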
@@ -1063,14 +1045,14 @@ class ConfigurableTask(Task):
             gold = self.doc_to_target(doc)
 
             gold_index_error = False
-            if type(gold) is list:
+            if isinstance(gold, list):
                 gold = [i if i < len(choices) else -100 for i in gold]
                 if -100 in gold:
                     gold_index_error = True
             else:
-                if type(gold) is int:
+                if isinstance(gold, int):
                     gold = gold if gold < len(choices) else -100
-                elif type(gold) is str:
+                elif isinstance(gold, str):
                     gold = choices.index(gold) if gold in choices else -100
 
                 if gold == -100:
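The switch from type(...) is ... to isinstance(...) is behavioral, not just stylistic: isinstance also accepts subclasses of the target type. A quick standalone demonstration (the subclass is hypothetical, standing in for the int-like types a doc_to_target might return):

class GoldIndex(int):  # hypothetical int subclass for illustration
    pass

g = GoldIndex(2)
print(type(g) is int)      # False -- the old check would fall through
print(isinstance(g, int))  # True  -- the new check handles it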
@@ -1092,12 +1074,13 @@ class ConfigurableTask(Task):
             # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
             exact_match = int(is_greedy[gold]) if gold != -100 else 0
 
+            # gold, lls, is_greedy, completion_len
             result_dict = {
                 **({"acc": acc} if "acc" in use_metric else {}),
-                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
-                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                 **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
                 **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
+                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
+                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
             }
 
             if "acc_mutual_info" in use_metric:
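Moving the f1 and mcc entries only changes insertion order inside the dict literal; since the keys are distinct, the resulting mapping is identical. A two-line sanity check:

a = {**{"acc": 1}, **{"f1": (0, 1)}, **{"acc_norm": 1}}
b = {**{"acc": 1}, **{"acc_norm": 1}, **{"f1": (0, 1)}}
assert a == b  # equal as mappings; only key order differs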
@@ -1160,9 +1143,7 @@ class ConfigurableTask(Task):
                         predictions=[result],
                         **self._metric_fn_kwargs[metric],
                     )
-                except (
-                    TypeError
-                ):  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
+                except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                     result_score = self._metric_fn_list[metric]([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
@@ -1176,8 +1157,7 @@ class ConfigurableTask(Task):
         return result_dict
 
-    def aggregation(self):
-        # return self._aggregation_list
+    def compute_metric(self):
         return self._metric_fn_list
 
     def higher_is_better(self):
@@ -1224,7 +1204,7 @@ class MultipleChoiceTask(Task):
             "acc_norm": True,
         }
 
-    def aggregation(self) -> dict:
+    def compute_metric(self) -> dict:
         return {
             "acc": mean,
             "acc_norm": mean,
@@ -1285,7 +1265,7 @@ class PerplexityTask(Task):
             "bits_per_byte": (loglikelihood, bytes_),
         }
 
-    def aggregation(self) -> dict:
+    def compute_metric(self) -> dict:
         return {
             "word_perplexity": weighted_perplexity,
             "byte_perplexity": weighted_perplexity,
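Taken together, every aggregation() implementation in the file (the abstract hook, ConfigurableTask, MultipleChoiceTask, PerplexityTask) is renamed to compute_metric() while keeping the same name-to-callable mapping. A hedged sketch of how an evaluation loop might consume the renamed hook (task and per_doc_scores are hypothetical):

def aggregate_task(task, per_doc_scores):
    # per_doc_scores: {metric_name: [per-doc score values]}
    # Each callable (e.g. mean, weighted_perplexity) reduces a list to a float.
    return {
        name: agg_fn(per_doc_scores[name])
        for name, agg_fn in task.compute_metric().items()
    }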