Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
c746d1fb
Commit
c746d1fb
authored
Jun 05, 2023
by
lintangsutawika
Browse files
fixing metric for each output_type
parent
a339ffd8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
24 deletions
+52
-24
lm_eval/api/task.py
lm_eval/api/task.py
+52
-24
No files found.
lm_eval/api/task.py
View file @
c746d1fb
...
@@ -485,6 +485,30 @@ class ConfigurableTask(Task):
...
@@ -485,6 +485,30 @@ class ConfigurableTask(Task):
self
.
_metric_kwargs
=
{}
self
.
_metric_kwargs
=
{}
self
.
_aggregation_list
=
{}
self
.
_aggregation_list
=
{}
self
.
_higher_is_better
=
{}
self
.
_higher_is_better
=
{}
if
self
.
_config
.
output_type
!=
"greedy_util"
:
eval_logger
.
warning
(
f
"Output Type set as
{
self
.
_config
.
output_type
}
which does not use metric_list"
"metric list will be unused."
)
if
self
.
_config
.
output_type
==
"loglikelihood"
:
metric_list
=
[
"perplexity"
,
"acc"
]
elif
self
.
_config
.
output_type
==
"loglikelihood_rolling"
:
metric_list
=
[
"word_perplexity"
,
"byte_perplexity"
,
"bits_per_byte"
,
]
elif
self
.
_config
.
output_type
==
"multiple_choice"
:
metric_list
=
[
"acc"
,
"acc_norm"
]
for
metric_name
in
metric_list
:
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
"mean"
]
self
.
_higher_is_better
[
metric_name
]
=
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
else
:
for
metric_config
in
self
.
_config
.
metric_list
:
for
metric_config
in
self
.
_config
.
metric_list
:
metric_name
=
metric_config
[
"metric"
]
metric_name
=
metric_config
[
"metric"
]
...
@@ -496,7 +520,9 @@ class ConfigurableTask(Task):
...
@@ -496,7 +520,9 @@ class ConfigurableTask(Task):
if
key
not
in
[
"metric"
,
"aggregation"
,
"higher_is_better"
]
if
key
not
in
[
"metric"
,
"aggregation"
,
"higher_is_better"
]
}
}
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
aggregation
]
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
aggregation
]
if
metric_name
in
METRIC_REGISTRY
.
keys
():
if
metric_name
in
METRIC_REGISTRY
.
keys
():
self
.
_metric_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
self
.
_metric_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
...
@@ -512,7 +538,9 @@ class ConfigurableTask(Task):
...
@@ -512,7 +538,9 @@ class ConfigurableTask(Task):
except
Exception
:
except
Exception
:
raise
Warning
(
raise
Warning
(
"{} not found in the evaluate library!"
.
format
(
metric_name
),
"{} not found in the evaluate library!"
.
format
(
metric_name
),
"Please check https://huggingface.co/evaluate-metric"
,
"Please check https://huggingface.co/evaluate-metric"
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment