gaoqiong / lm-evaluation-harness

Commit 0a9ad6ee, authored Jun 06, 2023 by lintangsutawika
Parent: 5693abc5

much better way to process all metrics chosen
Showing 1 changed file with 82 additions and 69 deletions.

lm_eval/api/task.py  (+82, -69)
@@ -34,6 +34,13 @@ from lm_eval.logger import eval_logger
 from lm_eval.prompts import get_prompt
 from lm_eval.filters import build_filter_ensemble
 
+ALL_OUTPUT_TYPES = [
+    "loglikelihood",
+    "multiple_choice",
+    "loglikelihood_rolling",
+    "greedy_until",
+]
+
 
 @dataclass
 class TaskConfig(dict):
@@ -80,12 +87,12 @@ class TaskConfig(dict):
         # allow user-specified aliases so that users can
         # force prompt-compatibility for some prompt regardless of
         # field names in prompt
-        # if self.template_aliases is not None:
-        #     if type(self.doc_to_text) == str:
-        #         self.doc_to_text = self.template_aliases + self.doc_to_text
-        #     if type(self.doc_to_target) == str:
-        #         self.doc_to_target = self.template_aliases + self.doc_to_target
+        if self.template_aliases is not None:
+            if type(self.doc_to_text) == str:
+                self.doc_to_text = self.template_aliases + self.doc_to_text
+            if type(self.doc_to_target) == str:
+                self.doc_to_target = self.template_aliases + self.doc_to_target
 
         # set "task_name" metadata field based on the "primary" name set
         if self.names:
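For illustration only (not part of this commit): the template_aliases handling above is plain string concatenation, so a task config can declare Jinja-style aliases once and have them prepended to both doc_to_text and doc_to_target. The field values below are hypothetical.

    # hypothetical TaskConfig field values
    template_aliases = "{% set answer = label %}"
    doc_to_text = "{{question}}"
    doc_to_target = "{{answer}}"

    if type(doc_to_text) == str:
        doc_to_text = template_aliases + doc_to_text
    if type(doc_to_target) == str:
        doc_to_target = template_aliases + doc_to_target

    # doc_to_text   -> "{% set answer = label %}{{question}}"
    # doc_to_target -> "{% set answer = label %}{{answer}}"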
@@ -472,6 +479,7 @@ class ConfigurableTask(Task):
             )
 
         if self._config.output_type is not None:
+            assert self._config.output_type in ALL_OUTPUT_TYPES
             self.OUTPUT_TYPE = self._config.output_type
 
         if self._config.dataset_path is not None:
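For illustration only (not part of the diff): together with the new module-level ALL_OUTPUT_TYPES list, the added assert makes a misspelled output_type fail at task-construction time instead of surfacing later during evaluation. A minimal sketch with a hypothetical config value:

    ALL_OUTPUT_TYPES = [
        "loglikelihood",
        "multiple_choice",
        "loglikelihood_rolling",
        "greedy_until",
    ]

    output_type = "greedy"  # hypothetical typo in a task config
    # the new check: raises AssertionError immediately for the bad value
    assert output_type in ALL_OUTPUT_TYPES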
@@ -480,68 +488,71 @@ class ConfigurableTask(Task):
         if self._config.dataset_name is not None:
             self.DATASET_NAME = self._config.dataset_name
 
-        if self._config.metric_list is not None:
-            self._metric_list = {}
-            self._metric_kwargs = {}
-            self._aggregation_list = {}
-            self._higher_is_better = {}
-
-            if self._config.output_type == "greedy_until":
-                for metric_config in self._config.metric_list:
-                    metric_name = metric_config["metric"]
-                    aggregation = metric_config["aggregation"]
-                    higher_is_better = metric_config["higher_is_better"]
-                    kwargs = {
-                        key: metric_config[key]
-                        for key in metric_config
-                        if key not in ["metric", "aggregation", "higher_is_better"]
-                    }
-                    self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
-                    if metric_name in METRIC_REGISTRY.keys():
-                        self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
-                        self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
-                            metric_name
-                        ]
-                    else:
-                        self._higher_is_better[metric_name] = higher_is_better
-                        try:
-                            metric_object = evaluate.load(metric_name)
-                            self._metric_list[metric_name] = metric_object
-                            self._metric_kwargs[metric_name] = kwargs
-                        except Exception:
-                            raise Warning(
-                                "{} not found in the evaluate library!".format(
-                                    metric_name
-                                ),
-                                "Please check https://huggingface.co/evaluate-metric",
-                            )
-            else:
-                eval_logger.warning(
-                    f"Output Type set as {self._config.output_type} which does not use metric_list"
-                    "metric list will be unused."
-                )
-                if self._config.output_type == "loglikelihood":
-                    metric_list = ["perplexity", "acc"]
-                elif self._config.output_type == "loglikelihood_rolling":
-                    metric_list = [
-                        "word_perplexity",
-                        "byte_perplexity",
-                        "bits_per_byte",
-                    ]
-                elif self._config.output_type == "multiple_choice":
-                    metric_list = ["acc", "acc_norm"]
-                for metric_name in metric_list:
-                    self._aggregation_list[metric_name] = AGGREGATION_REGISTRY["mean"]
-                    self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
-                        metric_name
-                    ]
+        self._metric_fn_list = {}
+        self._metric_fn_kwargs = {}
+        self._aggregation_list = {}
+        self._higher_is_better = {}
+
+        if self._config.metric_list is None:
+            eval_logger.warning(
+                f"Output Type set as {self._config.output_type} and metric_list is not set"
+                "Will default to exact_match"
+            )
+            _metric_list = METRIC_REGISTRY[self._config.output_type]
+
+            for metric_name, metric_params in _metric_list.items():
+                self._metric_fn_list[metric_name] = metric_params["fn"]
+                self._aggregation_list[metric_name] = metric_params["aggregation"]
+                self._higher_is_better[metric_name] = metric_params["higher_is_better"]
+        else:
+            for metric_config in self._config.metric_list:
+                assert "metric" in metric_config
+                metric_name = metric_config["metric"]
+                kwargs = {
+                    key: metric_config[key]
+                    for key in metric_config
+                    if key not in ["metric", "aggregation", "higher_is_better"]
+                }
+
+                if metric_name in _metric_list:
+                    self._metric_fn_list[metric_name] = metric_params["fn"]
+                else:
+                    eval_logger.warning(
+                        f"Metric {metric_name} not found, "
+                        "Searching from https://huggingface.co/evaluate-metric"
+                    )
+                    try:
+                        metric_object = evaluate.load(metric_name)
+                        self._metric_fn_list[metric_name] = metric_object
+                        self._metric_fn_kwargs[metric_name] = kwargs
+                    except Exception:
+                        raise Warning(
+                            "{} not found in the evaluate library!".format(metric_name),
+                            "Please check https://huggingface.co/evaluate-metric",
+                        )
+
+                if "aggregation" in metric_config:
+                    self._aggregation_list[metric_name] = metric_config["aggregation"]
+                else:
+                    eval_logger.warning(
+                        f"metric {metric_name} is defined, but aggregation is not"
+                        f"using default aggregation for {metric_name}"
+                    )
+                    self._aggregation_list[metric_name] = _metric_list[metric_name]["aggregation"]
+
+                if "higher_is_better" in metric_config:
+                    self._higher_is_better[metric_name] = metric_config[
+                        "higher_is_better"
+                    ]
+                else:
+                    eval_logger.warning(
+                        f"metric {metric_name} is defined, but higher_is_better is not"
+                        f"using default higher_is_better for {metric_name}"
+                    )
+                    self._higher_is_better[metric_name] = _metric_list[metric_name]["higher_is_better"]
 
         self.download(self._config.dataset_kwargs)
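For illustration only (not part of the diff): under the new processing, each metric_list entry is a dict whose "metric" key is required, whose "aggregation" and "higher_is_better" keys are optional (falling back to registry defaults with a warning), and whose remaining keys are collected as kwargs for the metric function. A hypothetical entry and how the kwargs comprehension above splits it:

    # hypothetical entry from a task's metric_list
    metric_config = {
        "metric": "exact_match",
        "aggregation": "mean",
        "higher_is_better": True,
        "ignore_case": True,  # any extra key is forwarded to the metric as a kwarg
    }

    metric_name = metric_config["metric"]
    kwargs = {
        key: metric_config[key]
        for key in metric_config
        if key not in ["metric", "aggregation", "higher_is_better"]
    }

    # metric_name == "exact_match"
    # kwargs      == {"ignore_case": True}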
@@ -743,18 +754,19 @@ class ConfigurableTask(Task):
             result_dict = {"perplexity": ll, "acc": int(is_greedy)}
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             (loglikelihood,) = results
-            words = self.count_words(self.doc_to_target(doc))
-            bytes_ = self.count_bytes(self.doc_to_target(doc))
+            _words = self.count_words(self.doc_to_target(doc))
+            _bytes = self.count_bytes(self.doc_to_target(doc))
             return {
-                "word_perplexity": (loglikelihood, words),
-                "byte_perplexity": (loglikelihood, bytes_),
-                "bits_per_byte": (loglikelihood, bytes_),
+                "word_perplexity": (loglikelihood, _words),
+                "byte_perplexity": (loglikelihood, _bytes),
+                "bits_per_byte": (loglikelihood, _bytes),
             }
         elif self.OUTPUT_TYPE == "multiple_choice":
             lls = [
                 res[0] for res in results
             ]  # only retain loglikelihoods, discard is_greedy
             gold = int(self.doc_to_target(doc))
+            pred = np.argmax(lls)
             # retrieve choices in List[str] form, to compute choice lengths, etc.
             choices = ast.literal_eval(
                 utils.apply_template(
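For illustration only (not part of the diff): for loglikelihood_rolling each document now contributes (loglikelihood, length) tuples, with the renamed _words feeding word_perplexity and _bytes feeding both byte_perplexity and bits_per_byte. A sketch of the returned shape with made-up numbers:

    loglikelihood = -42.7  # hypothetical summed log-likelihood of the target text
    _words = 85            # hypothetical count_words(...) result
    _bytes = 512           # hypothetical count_bytes(...) result

    result = {
        "word_perplexity": (loglikelihood, _words),
        "byte_perplexity": (loglikelihood, _bytes),
        "bits_per_byte": (loglikelihood, _bytes),
    }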
@@ -778,6 +790,7 @@ class ConfigurableTask(Task):
             result_dict = {
                 "acc": acc,
+                "f1": (pred, gold),
                 "acc_norm": acc_norm,
             }
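For illustration only (not part of the diff): the multiple_choice branch now also records an "f1" entry that pairs the prediction from the newly added pred = np.argmax(lls) with the gold index, alongside acc and acc_norm. A sketch with hypothetical per-choice log-likelihoods:

    import numpy as np

    lls = [-4.1, -2.3, -3.7]    # hypothetical log-likelihood of each answer choice
    gold = 1                    # hypothetical gold choice index
    pred = int(np.argmax(lls))  # index of the most likely choice -> 1

    # assumed accuracy computation for this sketch; the diff only shows the dict
    acc = 1.0 if pred == gold else 0.0
    result_dict = {
        "acc": acc,
        "f1": (pred, gold),  # (prediction, reference) pair for later F1 aggregation
    }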
@@ -814,7 +827,7 @@ class ConfigurableTask(Task):
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
-                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until'",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
             )
 
         return result_dict