Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
028f04c7
Commit
028f04c7
authored
Dec 19, 2023
by
lintangsutawika
Browse files
loglikelihood and loglikelihood rolling modified
parent
1d262a59
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
56 deletions
+42
-56
lm_eval/api/task.py
lm_eval/api/task.py
+42
-56
No files found.
lm_eval/api/task.py
View file @
028f04c7
...
@@ -566,11 +566,16 @@ class ConfigurableTask(Task):
...
@@ -566,11 +566,16 @@ class ConfigurableTask(Task):
_metric_list
=
DEFAULT_METRIC_REGISTRY
[
self
.
config
.
output_type
]
_metric_list
=
DEFAULT_METRIC_REGISTRY
[
self
.
config
.
output_type
]
for
metric_name
in
_metric_list
:
for
metric_name
in
_metric_list
:
self
.
_metric_fn_list
[
metric_name
]
=
get_metric
(
metric_name
)
metric
=
get_metric
(
metric_name
)
self
.
_metric_fn_list
[
metric_name
]
=
metric
self
.
_metric_fn_kwargs
[
metric_name
]
=
{}
self
.
_metric_fn_kwargs
[
metric_name
]
=
{}
self
.
_aggregation_list
[
metric_name
]
=
get_metric_aggregation
(
self
.
_aggregation_list
[
metric_name
]
=
metric
.
aggregation
metric_name
# try:
)
# self._aggregation_list[metric_name] = metric.aggregation
# except:
# self._aggregation_list[metric_name] = get_metric_aggregation(
# metric_name
# )
self
.
_higher_is_better
[
metric_name
]
=
is_higher_better
(
metric_name
)
self
.
_higher_is_better
[
metric_name
]
=
is_higher_better
(
metric_name
)
else
:
else
:
for
metric_config
in
self
.
config
.
metric_list
:
for
metric_config
in
self
.
config
.
metric_list
:
...
@@ -601,35 +606,35 @@ class ConfigurableTask(Task):
...
@@ -601,35 +606,35 @@ class ConfigurableTask(Task):
)
)
self
.
_metric_fn_kwargs
[
metric_name
]
=
kwargs
self
.
_metric_fn_kwargs
[
metric_name
]
=
kwargs
if
"aggregation"
in
metric_config
:
#
if "aggregation" in metric_config:
agg_name
=
metric_config
[
"aggregation"
]
#
agg_name = metric_config["aggregation"]
if
type
(
agg_name
)
==
str
:
#
if type(agg_name) == str:
self
.
_aggregation_list
[
metric_name
]
=
get_aggregation
(
agg_name
)
#
self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif
callable
(
agg_name
):
#
elif callable(agg_name):
self
.
_aggregation_list
[
metric_name
]
=
metric_config
[
#
self._aggregation_list[metric_name] = metric_config[
"aggregation"
#
"aggregation"
]
#
]
else
:
#
else:
INV_AGG_REGISTRY
=
{
v
:
k
for
k
,
v
in
AGGREGATION_REGISTRY
.
items
()}
#
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg
=
get_metric_aggregation
(
metric_name
)
#
metric_agg = get_metric_aggregation(metric_name)
eval_logger
.
warning
(
#
eval_logger.warning(
f
"[Task:
{
self
.
_config
.
task
}
] metric
{
metric_name
}
is defined, but aggregation is not. "
#
f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
f
"using default "
#
f"using default "
f
"aggregation=
{
INV_AGG_REGISTRY
[
metric_agg
]
}
"
#
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
#
)
self
.
_aggregation_list
[
metric_name
]
=
metric_agg
#
self._aggregation_list[metric_name] = metric_agg
if
"higher_is_better"
in
metric_config
:
#
if "higher_is_better" in metric_config:
self
.
_higher_is_better
[
metric_name
]
=
metric_config
[
#
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
#
"higher_is_better"
]
#
]
else
:
#
else:
eval_logger
.
warning
(
#
eval_logger.warning(
f
"[Task:
{
self
.
_config
.
task
}
] metric
{
metric_name
}
is defined, but higher_is_better is not. "
#
f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f
"using default "
#
f"using default "
f
"higher_is_better=
{
is_higher_better
(
metric_name
)
}
"
#
f"higher_is_better={is_higher_better(metric_name)}"
)
#
)
self
.
_higher_is_better
[
metric_name
]
=
is_higher_better
(
metric_name
)
#
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self
.
download
(
self
.
config
.
dataset_kwargs
)
self
.
download
(
self
.
config
.
dataset_kwargs
)
self
.
_training_docs
=
None
self
.
_training_docs
=
None
...
@@ -1022,35 +1027,15 @@ class ConfigurableTask(Task):
...
@@ -1022,35 +1027,15 @@ class ConfigurableTask(Task):
return
self
.
config
.
process_results
(
doc
,
results
)
return
self
.
config
.
process_results
(
doc
,
results
)
result_dict
=
{}
result_dict
=
{}
use_metric
=
list
(
self
.
_metric_fn_list
.
keys
())
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
results
=
results
[
0
]
results
=
results
[
0
]
ll
,
is_greedy
=
results
ll
,
is_greedy
=
results
return
{
return
ll
,
is_greedy
**
({
"perplexity"
:
ll
}
if
"perplexity"
in
use_metric
else
{}),
**
({
"acc"
:
int
(
is_greedy
)}
if
"acc"
in
use_metric
else
{}),
}
elif
self
.
OUTPUT_TYPE
==
"loglikelihood_rolling"
:
elif
self
.
OUTPUT_TYPE
==
"loglikelihood_rolling"
:
(
loglikelihood
,)
=
results
(
loglikelihood
,)
=
results
_words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
_words
=
self
.
count_words
(
self
.
doc_to_target
(
doc
))
_bytes
=
self
.
count_bytes
(
self
.
doc_to_target
(
doc
))
_bytes
=
self
.
count_bytes
(
self
.
doc_to_target
(
doc
))
return
{
return
loglikelihood
,
_words
,
_bytes
**
(
{
"word_perplexity"
:
(
loglikelihood
,
_words
)}
if
"word_perplexity"
in
use_metric
else
{}
),
**
(
{
"byte_perplexity"
:
(
loglikelihood
,
_bytes
)}
if
"byte_perplexity"
in
use_metric
else
{}
),
**
(
{
"bits_per_byte"
:
(
loglikelihood
,
_bytes
)}
if
"bits_per_byte"
in
use_metric
else
{}
),
}
elif
self
.
OUTPUT_TYPE
==
"multiple_choice"
:
elif
self
.
OUTPUT_TYPE
==
"multiple_choice"
:
lls
,
is_greedy
=
zip
(
*
results
)
lls
,
is_greedy
=
zip
(
*
results
)
...
@@ -1192,7 +1177,8 @@ class ConfigurableTask(Task):
...
@@ -1192,7 +1177,8 @@ class ConfigurableTask(Task):
return
result_dict
return
result_dict
def
aggregation
(
self
):
def
aggregation
(
self
):
return
self
.
_aggregation_list
# return self._aggregation_list
return
self
.
_metric_fn_list
def
higher_is_better
(
self
):
def
higher_is_better
(
self
):
return
self
.
_higher_is_better
return
self
.
_higher_is_better
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment