Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
cea713dc
Commit
cea713dc
authored
Apr 28, 2023
by
lintangsutawika
Browse files
can process looglikelihood requests
parent
84191b83
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
19 deletions
+41
-19
lm_eval/api/task.py
lm_eval/api/task.py
+41
-19
No files found.
lm_eval/api/task.py
View file @
cea713dc
...
@@ -11,6 +11,7 @@ import numpy as np
...
@@ -11,6 +11,7 @@ import numpy as np
from
typing
import
List
,
Union
from
typing
import
List
,
Union
from
lm_eval.api
import
METRIC_REGISTRY
,
AGGREGATION_REGISTRY
,
HIGHER_IS_BETTER_REGISTRY
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.instance
import
Instance
from
lm_eval.api.metrics
import
get_metric
,
get_aggregation
,
mean
,
weighted_perplexity
,
bits_per_byte
from
lm_eval.api.metrics
import
get_metric
,
get_aggregation
,
mean
,
weighted_perplexity
,
bits_per_byte
from
lm_eval
import
utils
from
lm_eval
import
utils
...
@@ -45,6 +46,8 @@ class TaskConfig(dict):
...
@@ -45,6 +46,8 @@ class TaskConfig(dict):
filters
:
str
=
None
#TODO: need to make this typehint `list`?
filters
:
str
=
None
#TODO: need to make this typehint `list`?
normalization
:
str
=
None
# TODO: add length-normalization of various types, mutual info
normalization
:
str
=
None
# TODO: add length-normalization of various types, mutual info
stop_sequences
:
list
=
None
# TODO: allow passing of stop sequences to greedy gen.
stop_sequences
:
list
=
None
# TODO: allow passing of stop sequences to greedy gen.
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
str
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# allow user-specified aliases so that users can
# allow user-specified aliases so that users can
...
@@ -379,6 +382,10 @@ class ConfigurableTask(Task):
...
@@ -379,6 +382,10 @@ class ConfigurableTask(Task):
):
):
self
.
_config
=
TaskConfig
(
**
config
)
self
.
_config
=
TaskConfig
(
**
config
)
if
self
.
_config
.
output_type
is
not
None
:
self
.
OUTPUT_TYPE
=
self
.
_config
.
output_type
if
self
.
_config
.
dataset_path
is
not
None
:
if
self
.
_config
.
dataset_path
is
not
None
:
self
.
DATASET_PATH
=
self
.
_config
.
dataset_path
self
.
DATASET_PATH
=
self
.
_config
.
dataset_path
...
@@ -392,17 +399,18 @@ class ConfigurableTask(Task):
...
@@ -392,17 +399,18 @@ class ConfigurableTask(Task):
self
.
_metric_kwargs
=
{}
self
.
_metric_kwargs
=
{}
for
metric_config
in
self
.
_config
.
metric_list
:
for
metric_config
in
self
.
_config
.
metric_list
:
metric_name
=
metric_config
[
'
na
me'
]
metric_name
=
metric_config
[
'me
tric
'
]
aggregation
=
metric_config
[
'aggregation'
]
aggregation
=
metric_config
[
'aggregation'
]
higher_is_better
=
metric_config
[
'higher_is_better'
]
higher_is_better
=
metric_config
[
'higher_is_better'
]
kwargs
=
{
key
:
metric_config
[
key
]
for
key
in
metric_config
if
key
not
in
[
'name'
,
'aggregation'
,
'higher_is_better'
]}
kwargs
=
{
key
:
metric_config
[
key
]
for
key
in
metric_config
if
key
not
in
[
'name'
,
'aggregation'
,
'higher_is_better'
]}
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
aggregation
]
self
.
_aggregation_list
[
metric_name
]
=
AGGREGATION_REGISTRY
[
aggregation
]
self
.
_higher_is_better
[
metric_name
]
=
higher_is_better
if
metric_name
in
METRIC_REGISTRY
.
keys
():
if
metric_name
in
METRIC_REGISTRY
.
keys
():
self
.
_metric_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
self
.
_metric_list
[
metric_name
]
=
METRIC_REGISTRY
[
metric_name
]
self
.
_higher_is_better
[
metric_name
]
=
HIGHER_IS_BETTER_REGISTRY
[
metric_name
]
else
:
else
:
self
.
_higher_is_better
[
metric_name
]
=
higher_is_better
try
:
try
:
metric_object
=
evaluate
.
load
(
metric_name
)
metric_object
=
evaluate
.
load
(
metric_name
)
self
.
_metric_list
[
metric_name
]
=
metric_object
self
.
_metric_list
[
metric_name
]
=
metric_object
...
@@ -454,6 +462,13 @@ class ConfigurableTask(Task):
...
@@ -454,6 +462,13 @@ class ConfigurableTask(Task):
if
self
.
_config
.
test_split
is
not
None
:
if
self
.
_config
.
test_split
is
not
None
:
return
self
.
dataset
[
self
.
_config
.
test_split
]
return
self
.
dataset
[
self
.
_config
.
test_split
]
def
should_decontaminate
(
self
):
return
self
.
_config
.
should_decontaminate
def
doc_to_decontamination_query
(
self
,
doc
):
if
self
.
_config
.
should_decontaminate
:
return
utils
.
apply_template
(
self
.
_config
.
doc_to_decontamination_query
,
doc
)
def
_process_doc
(
self
,
doc
):
def
_process_doc
(
self
,
doc
):
"""
"""
Override this to process (detokenize, strip, replace, etc.) individual
Override this to process (detokenize, strip, replace, etc.) individual
...
@@ -473,15 +488,15 @@ class ConfigurableTask(Task):
...
@@ -473,15 +488,15 @@ class ConfigurableTask(Task):
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
def
construct_requests
(
self
,
doc
,
ctx
,
**
kwargs
):
if
self
.
output_type
==
"loglikelihood"
:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
arguments
=
(
ctx
,
self
.
doc_to_target
(
doc
))
arguments
=
(
ctx
,
self
.
doc_to_target
(
doc
))
elif
self
.
output_type
==
"loglikelihood_rolling"
:
elif
self
.
OUTPUT_TYPE
==
"loglikelihood_rolling"
:
arguments
=
(
self
.
doc_to_target
(
doc
),)
arguments
=
(
self
.
doc_to_target
(
doc
),)
elif
self
.
output_type
==
"greedy_until"
:
elif
self
.
OUTPUT_TYPE
==
"greedy_until"
:
arguments
=
(
ctx
,
"
\n\n
"
)
arguments
=
(
ctx
,
"
\n\n
"
)
return
Instance
(
return
Instance
(
request_type
=
self
.
output_type
,
request_type
=
self
.
OUTPUT_TYPE
,
doc
=
doc
,
doc
=
doc
,
arguments
=
arguments
,
arguments
=
arguments
,
**
kwargs
**
kwargs
...
@@ -489,28 +504,35 @@ class ConfigurableTask(Task):
...
@@ -489,28 +504,35 @@ class ConfigurableTask(Task):
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
if
self
.
_config
.
gold_alias
is
not
None
:
gold
=
doc
[
self
.
_config
.
gold_alias
]
else
:
gold
=
self
.
doc_to_target
(
doc
)
result_dict
=
{}
result_dict
=
{}
for
key
,
result
in
zip
(
self
.
_metric_list
.
keys
(),
results
):
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
_dict
=
self
.
_metric_list
[
key
](
results
=
results
[
0
]
references
=
[
gold
],
ll
,
is_greedy
=
results
predictions
=
[
result
],
result_dict
=
{
"perplexity"
:
ll
,
"accuracy"
:
int
(
is_greedy
)}
)
elif
self
.
OUTPUT_TYPE
==
"loglikelihood_rolling"
:
pass
elif
self
.
OUTPUT_TYPE
==
"greedy_until"
:
if
self
.
_config
.
gold_alias
is
not
None
:
gold
=
doc
[
self
.
_config
.
gold_alias
]
else
:
gold
=
self
.
doc_to_target
(
doc
)
for
key
,
result
in
zip
(
self
.
_metric_list
.
keys
(),
results
):
_dict
=
self
.
_metric_list
[
key
].
compute
(
references
=
[
gold
],
predictions
=
[
result
],
**
self
.
_metric_kwargs
[
key
]
)
result_dict
[
key
]
=
_dict
[
key
]
result_dict
[
key
]
=
_dict
[
key
]
return
result_dict
return
result_dict
def
aggregation
(
self
):
def
aggregation
(
self
):
return
self
.
_aggregation_list
return
self
.
_aggregation_list
def
higher_is_better
(
self
):
def
higher_is_better
(
self
):
return
self
.
_higher_is_better
return
self
.
_higher_is_better
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment